segment_classifier-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- segment_classifier/__init__.py +4 -0
- segment_classifier/cache/__init__.py +4 -0
- segment_classifier/cache/l1_cache.py +71 -0
- segment_classifier/cache/l2_cache.py +157 -0
- segment_classifier/config.py +53 -0
- segment_classifier/models.py +142 -0
- segment_classifier/pipeline.py +173 -0
- segment_classifier/stages/__init__.py +6 -0
- segment_classifier/stages/fingerprint.py +10 -0
- segment_classifier/stages/fuzzy_cluster.py +101 -0
- segment_classifier/stages/llm_classifier.py +271 -0
- segment_classifier/stages/rule_based.py +287 -0
- segment_classifier/utils/__init__.py +3 -0
- segment_classifier/utils/html_normalizer.py +165 -0
- segment_classifier-0.1.0.dist-info/METADATA +95 -0
- segment_classifier-0.1.0.dist-info/RECORD +17 -0
- segment_classifier-0.1.0.dist-info/WHEEL +4 -0
segment_classifier/cache/l1_cache.py
@@ -0,0 +1,71 @@
import asyncio
import json
from pathlib import Path

import aiofiles
from pydantic import ValidationError

from segment_classifier.models import FingerprintRecord


class L1FingerprintCache:
    def __init__(self, cache_path: str, auto_persist_every: int = 50):
        self._path = Path(cache_path)
        self._store: dict[str, FingerprintRecord] = {}
        self._lock = asyncio.Lock()
        self._write_count = 0
        self._auto_persist_every = auto_persist_every

    async def load(self) -> None:
        if not self._path.exists():
            return

        async with aiofiles.open(self._path, "r", encoding="utf-8") as f:
            content = await f.read()
        if not content.strip():
            return

        try:
            data = json.loads(content)
            for key, val in data.items():
                self._store[key] = FingerprintRecord.model_validate(val)
        except (json.JSONDecodeError, ValidationError):
            # Corrupt or invalid cache file: discard any partially loaded
            # entries and start fresh. A real application should log this.
            self._store = {}

    async def get(self, fingerprint_hash: str) -> FingerprintRecord | None:
        async with self._lock:
            return self._store.get(fingerprint_hash)

    async def set(self, fingerprint_hash: str, record: FingerprintRecord) -> None:
        async with self._lock:
            self._store[fingerprint_hash] = record
            self._write_count += 1
            if self._write_count >= self._auto_persist_every:
                self._write_count = 0
                await self._persist_unsafe()

    async def increment_hit(self, fingerprint_hash: str) -> None:
        async with self._lock:
            record = self._store.get(fingerprint_hash)
            if record:
                record.hit_count += 1
                self._write_count += 1
                if self._write_count >= self._auto_persist_every:
                    self._write_count = 0
                    await self._persist_unsafe()

    async def _persist_unsafe(self) -> None:
        """Called internally when the lock is already acquired."""
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data = {k: v.model_dump(mode="json") for k, v in self._store.items()}
        async with aiofiles.open(self._path, "w", encoding="utf-8") as f:
            await f.write(json.dumps(data, indent=2))

    async def persist(self) -> None:
        async with self._lock:
            self._write_count = 0
            await self._persist_unsafe()

    @property
    def size(self) -> int:
        return len(self._store)
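For orientation, a minimal usage sketch of the L1 cache. The fingerprint hash "abc123" is a placeholder; FingerprintRecord and ComponentType are defined in models.py further down this diff.

import asyncio
from segment_classifier.cache.l1_cache import L1FingerprintCache
from segment_classifier.models import FingerprintRecord, ComponentType

async def main() -> None:
    cache = L1FingerprintCache(".cache/l1_fingerprints.json")
    await cache.load()

    # "abc123" is a placeholder fingerprint hash
    record = FingerprintRecord(
        fingerprint_hash="abc123",
        component_type=ComponentType.SECTION_HERO,
        confidence=0.92,
    )
    await cache.set("abc123", record)

    hit = await cache.get("abc123")
    if hit:
        await cache.increment_hit("abc123")

    await cache.persist()  # flush to disk regardless of the auto-persist counter

asyncio.run(main())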
segment_classifier/cache/l2_cache.py
@@ -0,0 +1,157 @@
import asyncio
import json
import uuid
from pathlib import Path

import aiofiles
import numpy as np
from pydantic import ValidationError

from segment_classifier.models import ClusterRecord, ComponentType


class L2FuzzyCache:
    def __init__(
        self,
        cache_path: str,
        embeddings_path: str,
        similarity_threshold: float = 0.85,
        max_cluster_size: int = 50,
        persist_on_update: bool = True,
    ):
        self._path = Path(cache_path)
        self._embeddings_path = Path(embeddings_path)
        self._similarity_threshold = similarity_threshold
        self._max_cluster_size = max_cluster_size
        self._persist_on_update = persist_on_update

        self._store: list[ClusterRecord] = []
        # Matrix of cluster centroids, row-parallel to self._store
        self._centroids: np.ndarray | None = None

        self._lock = asyncio.Lock()

    async def load(self) -> None:
        if not self._path.exists() or not self._embeddings_path.exists():
            return

        async with aiofiles.open(self._path, "r", encoding="utf-8") as f:
            content = await f.read()
        if content.strip():
            try:
                data = json.loads(content)
                self._store = [ClusterRecord.model_validate(val) for val in data]
            except (json.JSONDecodeError, ValidationError):
                self._store = []

        if self._store:
            try:
                self._centroids = np.load(str(self._embeddings_path))
                if len(self._store) != self._centroids.shape[0]:
                    # The JSON and .npy files are out of sync; reset both.
                    self._store = []
                    self._centroids = None
            except Exception:
                self._store = []
                self._centroids = None

    async def find_nearest(self, vector: list[float], threshold: float | None = None) -> ClusterRecord | None:
        async with self._lock:
            if not self._store or self._centroids is None:
                return None

            query = np.array(vector)
            query_norm = np.linalg.norm(query)
            if query_norm == 0:
                return None

            # Cosine similarity between the query and every centroid
            dot_products = np.dot(self._centroids, query)
            centroid_norms = np.linalg.norm(self._centroids, axis=1)
            # Avoid division by zero for degenerate centroids
            centroid_norms[centroid_norms == 0] = 1

            similarities = dot_products / (centroid_norms * query_norm)

            best_idx = int(np.argmax(similarities))
            best_sim = similarities[best_idx]

            check_threshold = threshold if threshold is not None else self._similarity_threshold

            if best_sim >= check_threshold:
                return self._store[best_idx]
            return None

    async def add_to_cluster(self, cluster_id: str, fingerprint_hash: str, vector: list[float]) -> None:
        async with self._lock:
            idx = next((i for i, c in enumerate(self._store) if c.cluster_id == cluster_id), None)
            if idx is None or self._centroids is None:
                return
            cluster = self._store[idx]

            # Respect the cluster size cap and skip duplicate members, so the
            # running-mean centroid is not skewed by repeated fingerprints.
            if (
                len(cluster.member_fingerprints) < self._max_cluster_size
                and fingerprint_hash not in cluster.member_fingerprints
            ):
                cluster.member_fingerprints.append(fingerprint_hash)

                # Update the centroid as a running mean:
                # new = (old * (n - 1) + vector) / n
                n = len(cluster.member_fingerprints)
                old_centroid = self._centroids[idx]
                new_vec = np.array(vector)
                new_centroid = (old_centroid * (n - 1) + new_vec) / n

                self._centroids[idx] = new_centroid
                cluster.centroid_vector = new_centroid.tolist()

                if self._persist_on_update:
                    await self._persist_unsafe()

    async def create_cluster(
        self,
        fingerprint_hash: str,
        vector: list[float],
        component_type: ComponentType,
        confidence: float,
    ) -> ClusterRecord:
        async with self._lock:
            cluster_id = str(uuid.uuid4())
            record = ClusterRecord(
                cluster_id=cluster_id,
                centroid_vector=vector,
                component_type=component_type,
                confidence=confidence,
                member_fingerprints=[fingerprint_hash],
            )

            self._store.append(record)

            new_vec = np.array([vector])
            if self._centroids is None:
                self._centroids = new_vec
            else:
                self._centroids = np.vstack([self._centroids, new_vec])

            if self._persist_on_update:
                await self._persist_unsafe()

            return record

    async def _persist_unsafe(self) -> None:
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._embeddings_path.parent.mkdir(parents=True, exist_ok=True)

        # Write cluster metadata as JSON
        data = [c.model_dump(mode="json") for c in self._store]
        async with aiofiles.open(self._path, "w", encoding="utf-8") as f:
            await f.write(json.dumps(data, indent=2))

        # Write the centroid matrix as .npy. numpy has no async writer, so
        # this write is synchronous; the matrix is small and primarily an
        # in-memory store, so the brief blocking call is acceptable here.
        if self._centroids is not None:
            np.save(str(self._embeddings_path), self._centroids)

    async def persist(self) -> None:
        async with self._lock:
            await self._persist_unsafe()

    @property
    def size(self) -> int:
        return len(self._store)
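The lookup-then-create flow the pipeline drives through this cache can be sketched as follows. The 4-dimensional vector and the fingerprint hash are toy placeholders; real embeddings would come from the fuzzy-cluster stage (fuzzy_cluster.py, not shown in this excerpt).

import asyncio
from segment_classifier.cache.l2_cache import L2FuzzyCache
from segment_classifier.models import ComponentType

async def main() -> None:
    cache = L2FuzzyCache(".cache/l2_clusters.json", ".cache/l2_embeddings.npy")
    await cache.load()

    vector = [0.1, 0.7, 0.2, 0.0]  # toy embedding
    match = await cache.find_nearest(vector)
    if match:
        # Cosine similarity cleared the 0.85 threshold: reuse the cluster's
        # label and fold this vector into the centroid's running mean.
        await cache.add_to_cluster(match.cluster_id, "fp-hash-1", vector)
    else:
        await cache.create_cluster(
            fingerprint_hash="fp-hash-1",
            vector=vector,
            component_type=ComponentType.SECTION_CTA,
            confidence=0.8,
        )

asyncio.run(main())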
segment_classifier/config.py
@@ -0,0 +1,53 @@
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class ModelFeatureConfig(BaseModel):
    """
    Feature-based LLM model routing.

    Selection priority (highest to lowest):
    1. high_complexity → used for ambiguous, deeply nested, multi-role segments
    2. standard → default for most unknown segments
    3. fast → simple segments with weak signals that are not rule-matchable

    Complexity is determined by:
    - dom_depth above the threshold
    - child_tag_counts diversity (many unique tags = complex)
    - text_density_ratio (very high or very low = complex)
    - sibling_count == 0 (one-off sections = complex)
    """
    high_complexity_model: str = "anthropic/claude-opus-4"
    standard_model: str = "anthropic/claude-sonnet-4-5"
    fast_model: str = "anthropic/claude-haiku-4-5"

    high_complexity_dom_depth_threshold: int = 6
    high_complexity_unique_tag_threshold: int = 8
    fast_model_max_dom_depth: int = 3


class CacheConfig(BaseModel):
    l1_cache_path: str = ".cache/l1_fingerprints.json"
    l2_cache_path: str = ".cache/l2_clusters.json"
    l2_embeddings_path: str = ".cache/l2_embeddings.npy"
    l2_similarity_threshold: float = 0.85
    l2_max_cluster_size: int = 50
    persist_on_update: bool = True


class ClassifierSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_prefix="CLASSIFIER_")

    # LiteLLM
    litellm_api_key: str = ""
    litellm_batch_size: int = 20  # max segments per LLM batch call
    litellm_max_concurrent_batches: int = 5
    litellm_timeout_seconds: int = 60

    # Pipeline
    rule_based_confidence_threshold: float = 0.90
    l1_min_confidence: float = 0.85
    l2_min_confidence: float = 0.75

    model_routing: ModelFeatureConfig = ModelFeatureConfig()
    cache: CacheConfig = CacheConfig()
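Because ClassifierSettings sets env_prefix="CLASSIFIER_", top-level fields can be overridden from the environment or a .env file; note that no env_nested_delimiter is configured, so the nested model_routing and cache sub-models keep their defaults unless overridden in code. A short sketch (the key value is a placeholder):

import os
from segment_classifier.config import ClassifierSettings

# Environment variables take precedence over .env values and field defaults.
os.environ["CLASSIFIER_LITELLM_API_KEY"] = "sk-example"   # placeholder key
os.environ["CLASSIFIER_LITELLM_BATCH_SIZE"] = "10"

settings = ClassifierSettings()
assert settings.litellm_batch_size == 10

# Nested config is overridden via init kwargs; the dict is validated
# into a CacheConfig instance by pydantic.
settings = ClassifierSettings(cache={"l2_similarity_threshold": 0.9})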
segment_classifier/models.py
@@ -0,0 +1,142 @@
from enum import Enum

from pydantic import BaseModel, Field


class ClassificationStage(str, Enum):
    RULE_BASED = "rule_based"
    L1_EXACT_CACHE = "l1_exact_cache"
    L2_FUZZY_CACHE = "l2_fuzzy_cache"
    LLM = "llm"


class SegmentPosition(str, Enum):
    TOP = "top"        # top 5% of page
    BOTTOM = "bottom"  # bottom 10% of page
    MIDDLE = "middle"
    UNKNOWN = "unknown"


class ComponentType(str, Enum):
    # Layout
    LAYOUT_HEADER = "layout.header"
    LAYOUT_FOOTER = "layout.footer"
    LAYOUT_NAV = "layout.nav"
    LAYOUT_SIDEBAR = "layout.sidebar"
    LAYOUT_BREADCRUMB = "layout.breadcrumb"

    # Collections
    COLLECTION_PRODUCT_CARD = "collection.product_card"
    COLLECTION_PRODUCT_LIST = "collection.product_list"
    COLLECTION_BLOG_CARD = "collection.blog_card"
    COLLECTION_BLOG_LIST = "collection.blog_list"
    COLLECTION_NEWS_ITEM = "collection.news_item"
    COLLECTION_NEWS_LIST = "collection.news_list"

    # Sections
    SECTION_HERO = "section.hero"
    SECTION_FEATURE_GRID = "section.feature_grid"
    SECTION_TESTIMONIAL = "section.testimonial"
    SECTION_CTA = "section.cta"
    SECTION_FAQ = "section.faq"
    SECTION_PRICING = "section.pricing"

    # UI elements
    UI_FORM = "ui.form"
    UI_MODAL = "ui.modal"
    UI_TABLE = "ui.table"
    UI_CAROUSEL = "ui.carousel"
    UI_PAGINATION = "ui.pagination"
    UI_SEARCH = "ui.search"

    # Content
    CONTENT_ARTICLE = "content.article"
    CONTENT_RICH_TEXT = "content.rich_text"
    CONTENT_MEDIA = "content.media"

    UNKNOWN = "unknown"


class InputSegment(BaseModel):
    """Raw segment from the page-segmenter tool."""
    segment_id: str
    page_url: str
    page_slug: str
    raw_html: str
    text_content: str
    position_hint: SegmentPosition = SegmentPosition.UNKNOWN
    dom_position: str = ""  # CSS selector path, e.g. "main > section:nth-child(2)"
    sibling_count: int = 0  # number of same-fingerprint siblings on the same page
    url_path_segments: list[str] = Field(default_factory=list)  # e.g. ["products", "shoes"]


class ClassifiedSegment(BaseModel):
    """Segment with classification result and metadata."""
    segment_id: str
    page_url: str
    page_slug: str
    raw_html: str
    text_content: str
    position_hint: SegmentPosition

    # Classification output
    component_type: ComponentType
    classification_stage: ClassificationStage
    confidence: float = Field(ge=0.0, le=1.0)

    # Fingerprint computed during the pipeline run
    fingerprint_hash: str = ""
    cluster_id: str | None = None

    # LLM metadata (populated only for stage=LLM)
    llm_model_used: str | None = None
    llm_raw_response: str | None = None


class FingerprintRecord(BaseModel):
    """Stored in the L1 cache: fingerprint → classification."""
    fingerprint_hash: str
    component_type: ComponentType
    confidence: float
    hit_count: int = 1
    example_segment_id: str = ""


class ClusterRecord(BaseModel):
    """Stored in the L2 cache: a cluster of similar fingerprints."""
    cluster_id: str
    centroid_vector: list[float]
    component_type: ComponentType
    confidence: float
    member_fingerprints: list[str] = Field(default_factory=list)


class LLMClassificationRequest(BaseModel):
    """One batch item sent to the LLM."""
    segment_id: str
    fingerprint_hash: str
    normalized_html: str  # skeleton only, no content
    position_hint: SegmentPosition
    sibling_count: int
    url_hints: list[str]
    dom_depth: int
    child_tag_counts: dict[str, int]
    text_density_ratio: float


class LLMClassificationResult(BaseModel):
    """Parsed result from the LLM for one segment."""
    segment_id: str
    component_type: ComponentType
    confidence: float
    reasoning: str


class PipelineResult(BaseModel):
    """Final output of a full classification pipeline run."""
    total_segments: int
    classified: list[ClassifiedSegment]
    stage_breakdown: dict[ClassificationStage, int]
    llm_calls_made: int
    llm_model_usage: dict[str, int]  # model name → call count
    cache_hit_rate: float
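For reference, an InputSegment as the upstream page-segmenter might emit it (all values are illustrative, not produced by the package):

from segment_classifier.models import InputSegment, SegmentPosition

segment = InputSegment(
    segment_id="seg-001",
    page_url="https://example.com/products/shoes",
    page_slug="products-shoes",
    raw_html="<section class='hero'><h1>Summer Sale</h1></section>",
    text_content="Summer Sale",
    position_hint=SegmentPosition.TOP,
    dom_position="main > section:nth-child(1)",
    sibling_count=0,
    url_path_segments=["products", "shoes"],
)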
segment_classifier/pipeline.py
@@ -0,0 +1,173 @@
import asyncio
import logging
import math

from segment_classifier.models import (
    InputSegment, ClassifiedSegment, PipelineResult, ClassificationStage, FingerprintRecord
)
from segment_classifier.config import ClassifierSettings
from segment_classifier.utils.html_normalizer import NormalizedSegment
from segment_classifier.stages.rule_based import RuleBasedClassifier
from segment_classifier.stages.fingerprint import compute_fingerprint
from segment_classifier.stages.fuzzy_cluster import FuzzyClusterStage
from segment_classifier.stages.llm_classifier import LLMBatchClassifier
from segment_classifier.cache.l1_cache import L1FingerprintCache
from segment_classifier.cache.l2_cache import L2FuzzyCache

logger = logging.getLogger(__name__)


class ClassifierPipeline:
    def __init__(self, settings: ClassifierSettings):
        self.settings = settings
        self.rule_classifier = RuleBasedClassifier(
            confidence_threshold=settings.rule_based_confidence_threshold
        )
        self.l1_cache = L1FingerprintCache(
            cache_path=settings.cache.l1_cache_path,
            auto_persist_every=50
        )
        self.l2_cache = L2FuzzyCache(
            cache_path=settings.cache.l2_cache_path,
            embeddings_path=settings.cache.l2_embeddings_path,
            similarity_threshold=settings.cache.l2_similarity_threshold,
            max_cluster_size=settings.cache.l2_max_cluster_size,
            persist_on_update=settings.cache.persist_on_update
        )
        self.fuzzy_stage = FuzzyClusterStage(self.l2_cache, settings.cache)
        self.llm_classifier = LLMBatchClassifier(settings)

    async def initialize(self) -> None:
        """Load caches from disk."""
        await asyncio.gather(
            self.l1_cache.load(),
            self.l2_cache.load()
        )

    async def shutdown(self) -> None:
        """Persist caches to disk."""
        await asyncio.gather(
            self.l1_cache.persist(),
            self.l2_cache.persist()
        )

    async def run(self, segments: list[InputSegment]) -> PipelineResult:
        total_segments = len(segments)
        classified: list[ClassifiedSegment] = []

        # Precompute fingerprints concurrently. compute_fingerprint is
        # CPU-bound, so run it in the default executor to avoid blocking
        # the event loop when many segments arrive at once.
        loop = asyncio.get_running_loop()
        fingerprints_list = await asyncio.gather(
            *[loop.run_in_executor(None, compute_fingerprint, seg) for seg in segments]
        )

        fingerprints: dict[str, tuple[NormalizedSegment, str]] = {
            seg.segment_id: fp for seg, fp in zip(segments, fingerprints_list)
        }

        pending = segments.copy()

        # Stage 1: rule-based
        next_pending = []
        for segment in pending:
            normalized, fp_hash = fingerprints[segment.segment_id]
            result = self.rule_classifier.classify(segment, normalized)
            if result:
                classified.append(result)
                await self.l1_cache.set(fp_hash, FingerprintRecord(
                    fingerprint_hash=fp_hash,
                    component_type=result.component_type,
                    confidence=result.confidence,
                    example_segment_id=segment.segment_id
                ))
            else:
                next_pending.append(segment)
        pending = next_pending

        # Stage 2: L1 exact cache
        next_pending = []
        for segment in pending:
            normalized, fp_hash = fingerprints[segment.segment_id]
            record = await self.l1_cache.get(fp_hash)
            if record and record.confidence >= self.settings.l1_min_confidence:
                classified.append(ClassifiedSegment(
                    segment_id=segment.segment_id,
                    page_url=segment.page_url,
                    page_slug=segment.page_slug,
                    raw_html=segment.raw_html,
                    text_content=segment.text_content,
                    position_hint=segment.position_hint,
                    component_type=record.component_type,
                    classification_stage=ClassificationStage.L1_EXACT_CACHE,
                    confidence=record.confidence,
                    fingerprint_hash=fp_hash
                ))
                await self.l1_cache.increment_hit(fp_hash)
            else:
                next_pending.append(segment)
        pending = next_pending

        # Stage 3: L2 fuzzy cache
        next_pending = []
        for segment in pending:
            normalized, fp_hash = fingerprints[segment.segment_id]
            result = await self.fuzzy_stage.classify(segment, normalized, fp_hash)
            if result and result.confidence >= self.settings.l2_min_confidence:
                classified.append(result)
            else:
                next_pending.append(segment)
        pending = next_pending

        # Stage 4: LLM batch
        llm_calls_made = 0
        if pending:
            llm_items = [
                (seg, fingerprints[seg.segment_id][0], fingerprints[seg.segment_id][1])
                for seg in pending
            ]
            llm_results = await self.llm_classifier.classify_batch(llm_items)

            # Register each LLM result in both caches
            for seg, result in zip(pending, llm_results):
                normalized, fp_hash = fingerprints[seg.segment_id]
                await self.l1_cache.set(fp_hash, FingerprintRecord(
                    fingerprint_hash=fp_hash,
                    component_type=result.component_type,
                    confidence=result.confidence,
                    example_segment_id=seg.segment_id
                ))
                await self.fuzzy_stage.register(
                    fingerprint_hash=fp_hash,
                    normalized=normalized,
                    component_type=result.component_type,
                    confidence=result.confidence
                )

            classified.extend(llm_results)

            # Count LLM batch calls: segments are grouped per model, then
            # sent in batches of at most litellm_batch_size.
            grouped_by_model: dict[str, int] = {}
            for seg, norm, _ in llm_items:
                model = self.llm_classifier.select_model(norm, seg)
                grouped_by_model[model] = grouped_by_model.get(model, 0) + 1

            for count in grouped_by_model.values():
                llm_calls_made += math.ceil(count / self.settings.litellm_batch_size)

        stage_breakdown = {stage: 0 for stage in ClassificationStage}
        for c in classified:
            stage_breakdown[c.classification_stage] += 1

        cache_hits = (
            stage_breakdown[ClassificationStage.L1_EXACT_CACHE]
            + stage_breakdown[ClassificationStage.L2_FUZZY_CACHE]
        )
        cache_hit_rate = cache_hits / total_segments if total_segments > 0 else 0.0

        return PipelineResult(
            total_segments=total_segments,
            classified=classified,
            stage_breakdown=stage_breakdown,
            llm_calls_made=llm_calls_made,
            llm_model_usage=self.llm_classifier.model_usage,
            cache_hit_rate=cache_hit_rate
        )
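A minimal driver for the four-stage waterfall, assuming segments have already been produced by the upstream page-segmenter:

import asyncio
from segment_classifier.config import ClassifierSettings
from segment_classifier.pipeline import ClassifierPipeline

async def classify(segments) -> None:
    pipeline = ClassifierPipeline(ClassifierSettings())
    await pipeline.initialize()          # load L1/L2 caches from disk
    try:
        result = await pipeline.run(segments)
        print(f"{result.total_segments} segments, "
              f"cache hit rate {result.cache_hit_rate:.0%}, "
              f"{result.llm_calls_made} LLM calls")
        for stage, count in result.stage_breakdown.items():
            print(f"  {stage.value}: {count}")
    finally:
        await pipeline.shutdown()        # persist caches

# asyncio.run(classify(my_segments))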
segment_classifier/stages/__init__.py
@@ -0,0 +1,6 @@
from .rule_based import RuleBasedClassifier
from .fingerprint import compute_fingerprint
from .fuzzy_cluster import FuzzyClusterStage
from .llm_classifier import LLMBatchClassifier

__all__ = ["RuleBasedClassifier", "compute_fingerprint", "FuzzyClusterStage", "LLMBatchClassifier"]
segment_classifier/stages/fingerprint.py
@@ -0,0 +1,10 @@
from segment_classifier.models import InputSegment
from segment_classifier.utils.html_normalizer import normalize_segment, NormalizedSegment


def compute_fingerprint(segment: InputSegment) -> tuple[NormalizedSegment, str]:
    """
    Computes the NormalizedSegment and fingerprint_hash for a given InputSegment.
    """
    normalized = normalize_segment(segment.raw_html, segment.text_content)
    fingerprint_hash = normalized.fingerprint_hash()
    return normalized, fingerprint_hash
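html_normalizer.py appears in the file list above (+165) but its hunk is not shown in this excerpt, so the exact fingerprint scheme is not visible here. For intuition only, a skeleton hash along these lines would hash the tag structure while ignoring text and attributes, so two segments rendered from the same template collide; this sketch is purely illustrative and is not the package's actual normalizer (it also assumes an HTML parser such as BeautifulSoup is available):

import hashlib
from bs4 import BeautifulSoup  # assumption: any HTML parser would do

def skeleton_hash(raw_html: str) -> str:
    """Hash only the tag skeleton so that two segments rendered from the
    same template (different text, same structure) produce the same hash."""
    soup = BeautifulSoup(raw_html, "html.parser")
    tags = [el.name for el in soup.find_all(True)]  # tag names in document order
    return hashlib.sha256("|".join(tags).encode("utf-8")).hexdigest()

# skeleton_hash("<div><h1>A</h1></div>") == skeleton_hash("<div><h1>B</h1></div>")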