embed-tree 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- embed_tree/__init__.py +121 -0
- embed_tree/cache/__init__.py +10 -0
- embed_tree/cache/json.py +7 -0
- embed_tree/cache/model.py +13 -0
- embed_tree/cache/sqlalchemy.py +7 -0
- embed_tree/config.py +100 -0
- embed_tree/embedders/__init__.py +7 -0
- embed_tree/embedders/huggingface.py +69 -0
- embed_tree/embedders/model.py +27 -0
- embed_tree/labelers/__init__.py +8 -0
- embed_tree/labelers/function.py +25 -0
- embed_tree/labelers/llm.py +26 -0
- embed_tree/labelers/model.py +38 -0
- embed_tree/loaders/__init__.py +17 -0
- embed_tree/loaders/filesystem.py +83 -0
- embed_tree/loaders/json.py +49 -0
- embed_tree/loaders/model.py +20 -0
- embed_tree/loaders/sqlalchemy.py +91 -0
- embed_tree/loaders/sqlalchemy_content.py +63 -0
- embed_tree/loaders/sqlite.py +21 -0
- embed_tree/persisters/__init__.py +15 -0
- embed_tree/persisters/filesystem.py +293 -0
- embed_tree/persisters/json.py +29 -0
- embed_tree/persisters/model.py +23 -0
- embed_tree/persisters/sqlalchemy.py +76 -0
- embed_tree/projectors/__init__.py +7 -0
- embed_tree/projectors/model.py +39 -0
- embed_tree/projectors/pca.py +57 -0
- embed_tree/providers/__init__.py +20 -0
- embed_tree/providers/base.py +104 -0
- embed_tree/providers/fake.py +26 -0
- embed_tree/providers/local.py +44 -0
- embed_tree/providers/openai.py +49 -0
- embed_tree/reconcilers/__init__.py +6 -0
- embed_tree/reconcilers/default.py +65 -0
- embed_tree/reconcilers/model.py +25 -0
- embed_tree/reducers.py +194 -0
- embed_tree/representation/__init__.py +27 -0
- embed_tree/representation/default.py +59 -0
- embed_tree/representation/model.py +87 -0
- embed_tree/store.py +5 -0
- embed_tree/stores/__init__.py +8 -0
- embed_tree/stores/file.py +32 -0
- embed_tree/stores/model.py +25 -0
- embed_tree/stores/null.py +16 -0
- embed_tree/taggers.py +132 -0
- embed_tree/tree.py +691 -0
- embed_tree-0.0.6.dist-info/METADATA +182 -0
- embed_tree-0.0.6.dist-info/RECORD +51 -0
- embed_tree-0.0.6.dist-info/WHEEL +5 -0
- embed_tree-0.0.6.dist-info/top_level.txt +1 -0
embed_tree/__init__.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""embed-tree: an incremental hierarchical clustering tree over embeddings.
|
|
2
|
+
|
|
3
|
+
See DESIGN.md for the full design. Minimal usage:
|
|
4
|
+
|
|
5
|
+
from embed_tree import EmbedTree, TreeConfig, FileTreeStore
|
|
6
|
+
|
|
7
|
+
tree = EmbedTree(embedder=my_embed_fn, store=FileTreeStore("./tree.json"))
|
|
8
|
+
tree.add("some content")
|
|
9
|
+
hits = tree.query("similar content", k=5)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from .config import LLMConfig, RebalanceConfig, TreeConfig
|
|
13
|
+
from .embedders import HuggingFaceTextEmbedder, TextEmbedder, embed_texts
|
|
14
|
+
from .labelers import FunctionLabeler, LabelCandidate, Labeler, LabelRequest, LLMLabeler
|
|
15
|
+
from .loaders import (
|
|
16
|
+
FileSystemTreeLoader,
|
|
17
|
+
JsonTreeLoader,
|
|
18
|
+
SQLAlchemyContentLoader,
|
|
19
|
+
SQLAlchemyTreeLoader,
|
|
20
|
+
SQLiteTreeLoader,
|
|
21
|
+
TreeLoader,
|
|
22
|
+
)
|
|
23
|
+
from .persisters import (
|
|
24
|
+
FileSystemTreePersister,
|
|
25
|
+
FolderTreePersister,
|
|
26
|
+
JsonTreePersister,
|
|
27
|
+
MaterializedTreeState,
|
|
28
|
+
SQLAlchemyTreePersister,
|
|
29
|
+
TreePersister,
|
|
30
|
+
)
|
|
31
|
+
from .projectors import PCAConfig, PCAProjector, VectorProjector
|
|
32
|
+
from .providers import (
|
|
33
|
+
EmbeddingProvider,
|
|
34
|
+
FakeEmbeddingProvider,
|
|
35
|
+
OpenAIEmbeddingProvider,
|
|
36
|
+
SentenceTransformerProvider,
|
|
37
|
+
)
|
|
38
|
+
from .reducers import (
|
|
39
|
+
FreezePCAReducer,
|
|
40
|
+
IdentityReducer,
|
|
41
|
+
IncrementalPCAReducer,
|
|
42
|
+
Reducer,
|
|
43
|
+
)
|
|
44
|
+
from .reconcilers import DefaultTreeReconciler, TreeReconciler
|
|
45
|
+
from .representation import (
|
|
46
|
+
ContentNode,
|
|
47
|
+
DefaultTreeRepresentation,
|
|
48
|
+
KeyNode,
|
|
49
|
+
NodeAggregate,
|
|
50
|
+
NodeEmbedding,
|
|
51
|
+
NodeId,
|
|
52
|
+
PartialTree,
|
|
53
|
+
TreeEdge,
|
|
54
|
+
VectorData,
|
|
55
|
+
partial_tree_from_dict,
|
|
56
|
+
partial_tree_to_dict,
|
|
57
|
+
)
|
|
58
|
+
from .store import FileTreeStore, NullTreeStore, TreeState, TreeStore
|
|
59
|
+
from .taggers import KeywordTagger, LLMTagger, Tagger, make_tagger
|
|
60
|
+
from .tree import EmbedTree, Item, Node
|
|
61
|
+
|
|
62
|
+
__all__ = [
    # Core tree and configuration
    "EmbedTree",
    "Item",
    "Node",
    "TreeConfig",
    "RebalanceConfig",
    "LLMConfig",
    # Embedding
    "TextEmbedder",
    "HuggingFaceTextEmbedder",
    "embed_texts",
    # Projection
    "PCAConfig",
    "PCAProjector",
    "VectorProjector",
    # Labeling
    "LabelCandidate",
    "LabelRequest",
    "Labeler",
    "FunctionLabeler",
    "LLMLabeler",
    # Loading
    "TreeLoader",
    "FileSystemTreeLoader",
    "JsonTreeLoader",
    "SQLAlchemyContentLoader",
    "SQLAlchemyTreeLoader",
    "SQLiteTreeLoader",
    # Persistence and reconciliation
    "MaterializedTreeState",
    "TreeReconciler",
    "DefaultTreeReconciler",
    "TreePersister",
    "FolderTreePersister",
    "FileSystemTreePersister",
    "JsonTreePersister",
    "SQLAlchemyTreePersister",
    # Representation
    "DefaultTreeRepresentation",
    "PartialTree",
    "ContentNode",
    "KeyNode",
    "TreeEdge",
    "NodeEmbedding",
    "NodeAggregate",
    "NodeId",
    "VectorData",
    "partial_tree_from_dict",
    "partial_tree_to_dict",
    # Stores
    "TreeState",
    "TreeStore",
    "FileTreeStore",
    "NullTreeStore",
    # Embedding providers
    "EmbeddingProvider",
    "FakeEmbeddingProvider",
    "OpenAIEmbeddingProvider",
    "SentenceTransformerProvider",
    # Dimensionality reducers
    "Reducer",
    "IdentityReducer",
    "FreezePCAReducer",
    "IncrementalPCAReducer",
    # Taggers
    "Tagger",
    "KeywordTagger",
    "LLMTagger",
    "make_tagger",
]
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Deprecated cache compatibility imports.
|
|
2
|
+
|
|
3
|
+
Use loaders plus persisters directly for new code.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .json import JsonTreeCache
|
|
7
|
+
from .model import MaterializedTreeState, TreeCache
|
|
8
|
+
from .sqlalchemy import SQLAlchemyTreeCache
|
|
9
|
+
|
|
10
|
+
# Deprecated compatibility re-exports; new code should use the loaders and
# persisters packages directly.
__all__ = [
    "MaterializedTreeState",
    "TreeCache",
    "JsonTreeCache",
    "SQLAlchemyTreeCache",
]
|
embed_tree/cache/json.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Deprecated cache compatibility contract."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Protocol, runtime_checkable
|
|
6
|
+
|
|
7
|
+
from embed_tree.loaders.model import TreeLoader
|
|
8
|
+
from embed_tree.persisters.model import MaterializedTreeState, TreePersister
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@runtime_checkable
class TreeCache(TreeLoader, TreePersister, Protocol):
    """Deprecated: the combined loader + persister interface.

    Kept only for backward compatibility. It declares no members of its own;
    everything comes from ``TreeLoader`` and ``TreePersister``, which new code
    should depend on separately.
    """
|
embed_tree/config.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Configuration for embed-tree, as a single pydantic object.
|
|
2
|
+
|
|
3
|
+
`TreeConfig` is a plain pydantic `BaseModel` (NOT `BaseSettings`): it is
|
|
4
|
+
constructed explicitly and handed whole to `EmbedTree(config=...)`. It does
|
|
5
|
+
**not** read environment variables — every value must be passed in code, so the
|
|
6
|
+
configuration is always explicit and reproducible.
|
|
7
|
+
|
|
8
|
+
See DESIGN.md §4/§5.3. M1 adds PCA dimensionality reduction (pca_dims) in two
|
|
9
|
+
modes (freeze / incremental); see those sections for the rebalance contract.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from typing import Literal
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RebalanceConfig(BaseModel):
    """When/whether to rebuild the whole tree from its leaves (DESIGN.md §4).

    This model only holds settings; no validation beyond pydantic's type
    checks is applied here.
    """

    # NOTE(review): presumably a master switch for all rebalancing — confirm
    # against the consumer in tree.py.
    enabled: bool = True
    # Auto-rebuild cadence; None disables. Values <= 0 are not rejected here —
    # NOTE(review): confirm how the caller treats them.
    every_n_inserts: int | None = 10_000
    on_demand: bool = True  # allow manual tree.rebalance()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class LLMConfig(BaseModel):
    """How to auto-name taxonomy nodes (DESIGN.md §10). Provider "none" uses a
    no-network keyword tagger; "openai" and "local" generate labels with an LLM.
    """

    # Labeling backend; pydantic rejects any value outside this Literal.
    provider: Literal["none", "openai", "local"] = "none"
    model: str = "gpt-4o-mini"  # OpenAI model id, or HF model id when local
    # OpenAI key (explicit; no env). NOTE(review): stored as a plain str, so it
    # is included in repr()/model_dump(); consider pydantic.SecretStr.
    api_key: str | None = None
    base_url: str | None = None  # OpenAI-compatible endpoint (e.g. a local server)
    max_samples: int = 15  # member texts shown to the LLM when naming a cluster
    max_label_words: int = 6  # keep labels short and browsable
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TreeConfig(BaseModel):
    """Top-level knobs. Defaults are tuned for the M0/M1 (<100k items) regime.

    Constructed in code only — no environment-variable loading. The validators
    at the end of the class reject inconsistent settings at construction time.
    Per pydantic v2 semantics, the ValueError checks surface as a
    ``ValidationError``, while the ``NotImplementedError`` raised for an
    unsupported ``split_algo`` propagates unchanged.
    """

    model_config = ConfigDict(protected_namespaces=())  # allow `model_args` name

    # --- Tree shape -----------------------------------------------------------
    # Defaults are tuned for a human-browsable taxonomy (DESIGN.md §10): a small
    # fan-out and small leaves keep every level readable (<=5 sub-topics, <=10
    # items per leaf). Raise both for a large-scale retrieval index instead.
    max_branches: int = 5  # max sub-topics per level (k for KMeans)
    leaf_capacity: int = 10  # max items in a leaf before it subdivides
    split_algo: str = "kmeans"  # M0: "kmeans" only

    # Distance is always cosine: vectors are L2-normalized and compared with
    # Euclidean (rank-equivalent on the unit sphere). Embeddings encode meaning
    # in direction, not magnitude, so there is no separate distance knob.

    # --- PCA dimensionality reduction (M1; see DESIGN.md §5.3) ---------------
    # Off by default: only worth it at scale (thousands+). At tens of items PCA
    # is meaningless (too few samples) and never reaches pca_warmup anyway.
    pca_dims: int | None = None  # None = no reduction (operate in raw space)
    pca_mode: Literal["freeze", "incremental"] = "freeze"
    pca_warmup: int = 1000  # items buffered before the first PCA fit
    pca_batch_size: int = 256  # incremental: partial_fit cadence

    # --- Nested settings --------------------------------------------------------
    rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig)
    llm: LLMConfig = Field(default_factory=LLMConfig)  # node auto-naming
    model_args: dict = Field(default_factory=dict)  # passed through to KMeans

    @field_validator("max_branches")
    @classmethod
    def _min_branches(cls, v: int) -> int:
        """Reject fan-outs below two: a single branch cannot split a node."""
        if v >= 2:
            return v
        raise ValueError("max_branches must be >= 2")

    @field_validator("split_algo")
    @classmethod
    def _supported_split(cls, v: str) -> str:
        """Accept only the splitting algorithms implemented so far."""
        if v == "kmeans":
            return v
        raise NotImplementedError(
            f"split_algo={v!r} arrives in a later milestone; "
            "M0 supports 'kmeans' only"
        )

    @model_validator(mode="after")
    def _cross_field(self) -> "TreeConfig":
        """Check the constraints that span several fields."""
        if self.leaf_capacity < self.max_branches:
            raise ValueError("leaf_capacity must be >= max_branches")
        dims = self.pca_dims
        if dims is None:
            # No reduction configured: the remaining PCA knobs are unused.
            return self
        if dims < 2:
            raise ValueError("pca_dims must be >= 2")
        # PCA needs at least n_components samples to fit / partial_fit.
        if self.pca_warmup < dims:
            raise ValueError("pca_warmup must be >= pca_dims")
        if self.pca_batch_size < dims:
            raise ValueError("pca_batch_size must be >= pca_dims")
        return self
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Open-source local embeddings via Hugging Face sentence-transformers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from embed_tree.providers.local import SentenceTransformerProvider
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class HuggingFaceTextEmbedder(SentenceTransformerProvider):
    """Sentence-transformers embedder with Mac-friendly device selection.

    The model is downloaded from Hugging Face by sentence-transformers on first
    use and then cached by that stack. With ``device="auto"`` the device is
    chosen from what PyTorch reports as available, in this order: MPS (Apple
    Silicon), then CUDA, then CPU. If PyTorch cannot be imported at all, the
    device is left as ``None`` and the choice is deferred to the underlying
    library.

    Args:
        model: Hugging Face model id to load.
        device: ``"auto"``, an explicit device string, or ``None``.
        cache_folder: Optional directory passed to ``SentenceTransformer`` as
            its ``cache_folder``.
        model_obj: A pre-built model object. When given, nothing is loaded and
            sentence-transformers is not imported.
        encode_kwargs: Forwarded to ``SentenceTransformerProvider``.
        **kwargs: Forwarded to ``SentenceTransformerProvider``.

    Raises:
        ImportError: ``model_obj`` is None and sentence-transformers is not
            installed.
    """

    def __init__(
        self,
        model: str = "BAAI/bge-small-en-v1.5",
        *,
        device: str | None = "auto",
        cache_folder: str | Path | None = None,
        model_obj: Any | None = None,
        encode_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        resolved_device = _resolve_device(device)
        self.model_name = model
        self.device = resolved_device
        self.cache_folder = None if cache_folder is None else str(cache_folder)
        if model_obj is None:
            # Imported lazily so the package works without the "local" extra
            # as long as a model object is supplied.
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as e:  # pragma: no cover
                raise ImportError(
                    'HuggingFaceTextEmbedder needs the "local" extra: '
                    'pip install "embed-tree[local]"'
                ) from e
            model_obj = SentenceTransformer(model, device=resolved_device, cache_folder=self.cache_folder)
        # Both paths hand the same arguments to the parent; only model_obj's
        # origin differs, so the call is made once here.
        super().__init__(
            model=model,
            device=resolved_device,
            model_obj=model_obj,
            encode_kwargs=encode_kwargs,
            **kwargs,
        )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _resolve_device(device: str | None) -> str | None:
    """Translate ``"auto"`` into a concrete torch device; pass anything else through.

    For ``"auto"`` it returns ``None`` when PyTorch is not importable. Otherwise
    it returns the first of ``"mps"`` / ``"cuda"`` that torch reports as
    available, falling back to ``"cpu"``.
    """
    if device == "auto":
        try:
            import torch
        except ImportError:
            return None
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def embed_texts(embedder: Any, texts: list[str]) -> np.ndarray:
    """Embed ``texts`` into one array using whichever interface ``embedder`` offers.

    A callable ``embed_batch`` attribute is preferred and receives the whole
    list in a single call. Otherwise ``embedder`` itself is called once per
    text. Either way, the result is passed through ``np.asarray``.
    """
    embed_batch = getattr(embedder, "embed_batch", None)
    if not callable(embed_batch):
        per_item = [embedder(text) for text in texts]
        return np.asarray(per_item)
    return np.asarray(embed_batch(texts))
|
|
69
|
+
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Embedding model contracts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Protocol, Sequence, runtime_checkable
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
Vector = np.ndarray  # alias used in the TextEmbedder signatures for one embedded string
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@runtime_checkable
class TextEmbedder(Protocol):
    """Structural type for objects that turn strings into embedding vectors.

    Any object exposing ``__call__``, ``embed`` and ``embed_batch`` passes
    ``isinstance(obj, TextEmbedder)``. Being ``runtime_checkable``, that check
    only looks for the members' presence, not their signatures.
    """

    def __call__(self, text: str) -> Vector:
        """Return the embedding of a single string (same contract as ``embed``)."""

    def embed(self, text: str) -> Vector:
        """Return the embedding of a single string."""

    def embed_batch(self, texts: Sequence[str]) -> np.ndarray:
        """Return the embeddings of ``texts``, in input order."""
|
|
27
|
+
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Function-backed labeler."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Callable, Iterable
|
|
6
|
+
|
|
7
|
+
from .model import LabelRequest
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class FunctionLabeler:
    """Wrap a plain callable so it satisfies the streaming labeler protocol.

    ``fn`` receives the ``LabelRequest`` and may return either one complete
    label string or an iterable of label chunks.
    """

    def __init__(self, fn: Callable[[LabelRequest], str | Iterable[str]]) -> None:
        self.fn = fn

    def stream(self, request: LabelRequest) -> Iterable[str]:
        """Lazily yield the chunks ``fn`` produces; a bare string is one chunk."""
        result = self.fn(request)
        chunks = (result,) if isinstance(result, str) else result
        yield from chunks

    def label(self, request: LabelRequest) -> str:
        """Concatenate every streamed chunk and strip surrounding whitespace."""
        pieces = list(self.stream(request))
        return "".join(pieces).strip()
|
|
25
|
+
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""LLM-backed streaming labeler."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Iterable
|
|
6
|
+
|
|
7
|
+
from embed_tree.config import LLMConfig
|
|
8
|
+
from embed_tree.taggers import LLMTagger
|
|
9
|
+
|
|
10
|
+
from .model import LabelRequest
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LLMLabeler:
    """Generate labels from nearby candidates using the existing LLM tagger.

    Only each candidate's ``text`` is forwarded to ``LLMTagger``. The request's
    ``max_words`` and the candidates' ids/distances/payloads are not consulted
    here. NOTE(review): label length is presumably governed by
    ``config.max_label_words`` inside the tagger instead — confirm.
    """

    def __init__(self, config: LLMConfig | None = None, *, client: Any | None = None, pipeline: Any | None = None) -> None:
        # A missing (falsy) config falls back to the default LLMConfig.
        effective_config = config if config else LLMConfig()
        self.config = effective_config
        self.tagger = LLMTagger(effective_config, client=client, pipeline=pipeline)

    def stream(self, request: LabelRequest) -> Iterable[str]:
        """Yield the complete label as a single chunk (no incremental output)."""
        full_label = self.label(request)
        yield full_label

    def label(self, request: LabelRequest) -> str:
        """Return the tagger's label for the candidates' texts, in request order."""
        candidate_texts = [item.text for item in request.candidates]
        return self.tagger(candidate_texts)
|
|
26
|
+
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Labeling model contracts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Iterable, Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True)
class LabelCandidate:
    """Nearby node or item used as context for a label.

    Immutable (``frozen=True``). Fields may be passed positionally in the
    declared order, so that order is part of the public API.
    """

    id: Any  # identifier of the candidate node/item (attribute name only; shadows builtin `id` as a field)
    text: str  # text a labeler works from
    distance: float | None = None  # presumably distance to the node being labeled; None when not computed — confirm
    payload: Any = None  # opaque extra data; not interpreted by this class
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class LabelRequest:
    """Context for generating a label.

    NOTE(review): although the dataclass is frozen, ``candidates`` is a list,
    so ``hash(request)`` raises TypeError and the list itself stays mutable.
    """

    candidates: list[LabelCandidate]  # context items, in caller-supplied order
    max_words: int = 6  # intended word limit for the label; enforcement is up to the Labeler
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@runtime_checkable
class Labeler(Protocol):
    """Structural interface for anything that can name a cluster from nearby candidates.

    Implementations offer both a chunked form (``stream``) and a one-shot form
    (``label``). The runtime check only verifies that both methods exist.
    """

    def label(self, request: LabelRequest) -> str:
        """Produce the complete label for ``request``."""

    def stream(self, request: LabelRequest) -> Iterable[str]:
        """Produce the label for ``request`` as successive text chunks."""
|
|
38
|
+
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Tree loader contracts."""
|
|
2
|
+
|
|
3
|
+
from .filesystem import FileSystemTreeLoader
|
|
4
|
+
from .json import JsonTreeLoader
|
|
5
|
+
from .model import TreeLoader
|
|
6
|
+
from .sqlalchemy_content import SQLAlchemyContentLoader
|
|
7
|
+
from .sqlalchemy import SQLAlchemyTreeLoader
|
|
8
|
+
from .sqlite import SQLiteTreeLoader
|
|
9
|
+
|
|
10
|
+
__all__ = [
    # Protocol
    "TreeLoader",
    # Concrete loaders
    "FileSystemTreeLoader",
    "JsonTreeLoader",
    "SQLAlchemyContentLoader",
    "SQLAlchemyTreeLoader",
    "SQLiteTreeLoader",
]
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Filesystem-backed ground-truth loader."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Iterable
|
|
8
|
+
|
|
9
|
+
from embed_tree.representation import ContentNode, KeyNode, PartialTree, TreeEdge
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class FileSystemTreeLoader:
    """Load files under a directory as content nodes.

    Directory nodes are emitted as ``KeyNode`` records with edges to their child
    directories/files, and the scanned root itself is the key node ``"."``.
    File node ids are MD5 hashes of their file bytes, so the same file keeps its
    identity when it moves locally. As a consequence, byte-identical files (for
    example several empty files) share one id. A ``ContentNode`` and an edge
    are still emitted for every such path.

    Args:
        root: Directory to scan recursively.
        include_suffixes: Optional whitelist of file suffixes.
            - Matching is case-insensitive against ``Path.suffix``, which holds
              only the final suffix (``".gz"``, not ``".tar.gz"``).
            - The leading dot is optional (``"md"`` equals ``".md"``).
            - ``""`` matches files without a suffix.
            - A single string counts as one suffix, not a sequence of
              characters.
            - ``None`` keeps every file.
        encoding: Encoding used to read file text. Files that fail to decode
            are skipped.
        include_hidden: When False, skip every path that has a component
            (relative to ``root``) starting with ``"."``.
    """

    def __init__(
        self,
        root: str | Path,
        *,
        include_suffixes: Iterable[str] | None = None,
        encoding: str = "utf-8",
        include_hidden: bool = False,
    ) -> None:
        self.root = Path(root)
        if include_suffixes is None:
            self.include_suffixes = None
        else:
            # A str is itself an Iterable[str]; without this it would be split
            # into single characters that can never equal a Path.suffix.
            if isinstance(include_suffixes, str):
                include_suffixes = [include_suffixes]
            self.include_suffixes = {self._normalize_suffix(s) for s in include_suffixes}
        self.encoding = encoding
        self.include_hidden = include_hidden

    def load(self) -> PartialTree | None:
        """Scan ``root`` and return its directories and text files as a PartialTree.

        Returns ``None`` when ``root`` does not exist. Paths are visited in sorted
        order. The following are skipped:
          - non-regular files (e.g. broken symlinks);
          - files excluded by ``include_suffixes``;
          - files that fail to decode with ``encoding``.
        Other I/O errors propagate.
        """
        if not self.root.exists():
            return None

        tree = PartialTree(metadata={"source": "filesystem", "root": str(self.root)})
        root_id = "."
        tree.key_nodes.append(KeyNode(id=root_id, label=self.root.name or str(self.root)))

        # NOTE: rglob still walks into hidden directories (e.g. .git); they are
        # only filtered out below.
        for path in sorted(self.root.rglob("*")):
            rel_path = path.relative_to(self.root)
            # Hidden-ness is judged on every component below root, so entries
            # inside a hidden directory are skipped as well.
            if not self.include_hidden and any(part.startswith(".") for part in rel_path.parts):
                continue
            rel = rel_path.as_posix()
            parent = rel_path.parent.as_posix() if path.parent != self.root else root_id
            if path.is_dir():
                tree.key_nodes.append(KeyNode(id=rel, label=path.name))
                tree.edges.append(TreeEdge(parent_id=parent, child_id=rel))
                continue
            if not path.is_file() or not self._included(path):
                continue
            try:
                content = path.read_text(encoding=self.encoding)
            except UnicodeDecodeError:
                # Not text in the configured encoding (e.g. a binary file).
                continue
            # Hash only after decoding succeeds, so skipped files are never hashed.
            file_id = _file_md5(path)
            tree.content_nodes.append(
                ContentNode(
                    id=file_id,
                    content=content,
                    text=path.stem,
                    payload={
                        "path": str(path),
                        "relative_path": rel,
                        "filename": path.name,
                    },
                    version=file_id,
                )
            )
            tree.edges.append(TreeEdge(parent_id=parent, child_id=file_id))

        return tree

    def _included(self, path: Path) -> bool:
        """Return True when ``path`` passes the suffix whitelist (or none is set)."""
        return self.include_suffixes is None or path.suffix.lower() in self.include_suffixes

    @staticmethod
    def _normalize_suffix(suffix: str) -> str:
        """Lower-case ``suffix`` and ensure a leading dot; ``""`` stays ``""``."""
        suffix = suffix.lower()
        if suffix and not suffix.startswith("."):
            suffix = f".{suffix}"
        return suffix
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _file_md5(path: Path) -> str:
    """Return the hex MD5 digest of the file at ``path``.

    The digest serves as a content identity, not as a security measure, so the
    hash is created with ``usedforsecurity=False``. Without that flag,
    ``hashlib.md5`` raises on FIPS-restricted OpenSSL builds. The file is read
    in 1 MiB chunks, so it is never held in memory all at once.
    """
    digest = hashlib.md5(usedforsecurity=False)
    with path.open("rb") as f:
        while chunk := f.read(1024 * 1024):
            digest.update(chunk)
    return digest.hexdigest()
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""JSON-backed tree loader."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from embed_tree.persisters.model import MaterializedTreeState
|
|
11
|
+
from embed_tree.representation import PartialTree
|
|
12
|
+
from embed_tree.representation.default import partial_tree_from_dict, partial_tree_to_dict
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class JsonTreeLoader:
    """Load/save a PartialTree or materialized state as one JSON file.

    ``save`` writes a ``kind``-tagged envelope, in one of two forms::

        {"kind": "partial_tree", "tree": {...}}
        {"kind": "materialized_tree_state", "state": ...}

    ``load`` unwraps either envelope. Any other JSON document (including a
    non-object top level) is returned as decoded.
    """

    def __init__(self, path: str | Path) -> None:
        self.path = Path(path)
        self.post_init()

    def post_init(self) -> None:
        """Hook for implementations that need setup after construction."""
        pass

    def load(self) -> PartialTree | MaterializedTreeState | None:
        """Read ``self.path``, returning None if the file does not exist.

        Returns:
            A PartialTree for a ``"partial_tree"`` envelope, the embedded value
            for a ``"materialized_tree_state"`` envelope, or the raw decoded JSON
            for anything else.
        """
        if not self.path.exists():
            return None
        with self.path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        # Only a JSON object can carry an envelope. A top-level list or scalar
        # previously crashed on `.get`; it is now passed through unchanged.
        kind = data.get("kind") if isinstance(data, dict) else None
        if kind == "partial_tree":
            return partial_tree_from_dict(data["tree"])
        if kind == "materialized_tree_state":
            return data["state"]
        return data

    def save(self, state: PartialTree | MaterializedTreeState) -> None:
        """Write ``state`` to ``self.path`` inside a ``kind``-tagged envelope.

        The JSON goes to a sibling temp file, which is flushed and fsynced, then
        moved over the target with ``os.replace``. The destination is therefore
        swapped in one step rather than rewritten in place.

        If anything fails before the replace (e.g. ``state`` is not
        JSON-serialisable), the temp file is removed and the error re-raised.
        """
        if isinstance(state, PartialTree):
            payload: dict[str, Any] = {"kind": "partial_tree", "tree": partial_tree_to_dict(state)}
        else:
            payload = {"kind": "materialized_tree_state", "state": state}

        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp = self.path.with_name(f"{self.path.name}.tmp.{os.getpid()}")
        try:
            with tmp.open("w", encoding="utf-8") as f:
                json.dump(payload, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, self.path)
        except BaseException:
            # Don't leave a stray half-written temp file next to the target.
            tmp.unlink(missing_ok=True)
            raise
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Abstract loader contract."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Protocol, runtime_checkable
|
|
6
|
+
|
|
7
|
+
from embed_tree.representation import PartialTree
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@runtime_checkable
class TreeLoader(Protocol):
    """Structural interface for anything that can produce a partial tree.

    Ground-truth inputs and reusable cached state both use this one shape. The
    role a loader plays is decided by where it is passed, not by the loader
    itself.
    """

    def load(self) -> PartialTree | None:
        """Return the loaded tree data, or ``None`` when the source is empty."""
|