segment_classifier 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/PKG-INFO +1 -1
  2. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/pyproject.toml +1 -1
  3. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/config.py +4 -3
  4. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/models.py +2 -2
  5. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/pipeline.py +84 -21
  6. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/stages/llm_classifier.py +64 -19
  7. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/utils/html_normalizer.py +65 -3
  8. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/README.md +0 -0
  9. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/__init__.py +0 -0
  10. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/cache/__init__.py +0 -0
  11. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/cache/l1_cache.py +0 -0
  12. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/cache/l2_cache.py +0 -0
  13. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/stages/__init__.py +0 -0
  14. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/stages/fingerprint.py +0 -0
  15. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/stages/fuzzy_cluster.py +0 -0
  16. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/stages/rule_based.py +0 -0
  17. {segment_classifier-0.1.0 → segment_classifier-0.1.1}/segment_classifier/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: segment_classifier
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: Async segment classifier library
5
5
  Author: Gagandeep Singh
6
6
  Author-email: gagan@innerkore.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "segment_classifier"
3
- version = "0.1.0"
3
+ version = "0.1.1"
4
4
  description = "Async segment classifier library"
5
5
  authors = ["Gagandeep Singh <gagan@innerkore.com>"]
6
6
  readme = "README.md"
@@ -17,9 +17,9 @@ class ModelFeatureConfig(BaseModel):
17
17
  - text_density_ratio (very high or very low = complex)
18
18
  - sibling_count == 0 (one-off sections = complex)
19
19
  """
20
- high_complexity_model: str = "anthropic/claude-opus-4"
21
- standard_model: str = "anthropic/claude-sonnet-4-5"
22
- fast_model: str = "anthropic/claude-haiku-4-5"
20
+ high_complexity_model: str = "high-complexity"
21
+ standard_model: str = "standard"
22
+ fast_model: str = "fast"
23
23
 
24
24
  high_complexity_dom_depth_threshold: int = 6
25
25
  high_complexity_unique_tag_threshold: int = 8
@@ -39,6 +39,7 @@ class ClassifierSettings(BaseSettings):
39
39
  model_config = SettingsConfigDict(env_file=".env", env_prefix="CLASSIFIER_")
40
40
 
41
41
  # LiteLLM
42
+ litellm_config_path: str = "litellm_config.yaml"
42
43
  litellm_api_key: str = ""
43
44
  litellm_batch_size: int = 20 # max segments per LLM batch call
44
45
  litellm_max_concurrent_batches: int = 5
@@ -112,10 +112,10 @@ class ClusterRecord(BaseModel):
112
112
 
113
113
 
114
114
  class LLMClassificationRequest(BaseModel):
115
- """Batch item sent to LLM."""
115
+ """Batch item sent to LLM. Use the provided raw HTML in normalized_html to understand purpose and content."""
116
116
  segment_id: str
117
117
  fingerprint_hash: str
118
- normalized_html: str # skeleton only, no content
118
+ normalized_html: str # raw HTML content of the segment
119
119
  position_hint: SegmentPosition
120
120
  sibling_count: int
121
121
  url_hints: list[str]
@@ -121,29 +121,92 @@ class ClassifierPipeline:
121
121
  # Stage 4: LLM Batch
122
122
  llm_calls_made = 0
123
123
  if pending:
124
- llm_items = [
125
- (seg, fingerprints[seg.segment_id][0], fingerprints[seg.segment_id][1])
126
- for seg in pending
127
- ]
128
- llm_results = await self.llm_classifier.classify_batch(llm_items)
124
+ # 1. Deduplicate by exact fingerprint
125
+ fp_to_segments: dict[str, list[InputSegment]] = {}
126
+ for seg in pending:
127
+ _, fp_hash = fingerprints[seg.segment_id]
128
+ fp_to_segments.setdefault(fp_hash, []).append(seg)
129
+
130
+ unique_fps = list(fp_to_segments.keys())
131
+
132
+ # 2. Group unique fingerprints into fuzzy clusters dynamically
133
+ dynamic_clusters: list[list[str]] = []
134
+ cluster_vectors: list[list[float]] = []
135
+
136
+ for fp_hash in unique_fps:
137
+ rep_seg = fp_to_segments[fp_hash][0]
138
+ normalized, _ = fingerprints[rep_seg.segment_id]
139
+ fp_string = self.fuzzy_stage._build_fingerprint_string(normalized)
140
+ vector = self.fuzzy_stage._vectorize(fp_string)
141
+
142
+ best_cluster_idx = -1
143
+ best_sim = -1.0
144
+ for i, c_vec in enumerate(cluster_vectors):
145
+ # Vectors from TfidfTransformer are L2-normalized, so dot product is cosine similarity
146
+ sim = sum(a * b for a, b in zip(vector, c_vec))
147
+ if sim > best_sim:
148
+ best_sim = sim
149
+ best_cluster_idx = i
150
+
151
+ if best_sim >= self.settings.cache.l2_similarity_threshold:
152
+ dynamic_clusters[best_cluster_idx].append(fp_hash)
153
+ else:
154
+ dynamic_clusters.append([fp_hash])
155
+ cluster_vectors.append(vector)
156
+
157
+ # 3. Prepare LLM items (one representative per dynamic cluster)
158
+ llm_items = []
159
+ for cluster_fps in dynamic_clusters:
160
+ rep_fp = cluster_fps[0]
161
+ rep_seg = fp_to_segments[rep_fp][0]
162
+ normalized = fingerprints[rep_seg.segment_id][0]
163
+ llm_items.append((rep_seg, normalized, rep_fp))
129
164
 
130
- # For each LLM result, register in L1 + L2
131
- for seg, result in zip(pending, llm_results):
132
- normalized, fp_hash = fingerprints[seg.segment_id]
133
- await self.l1_cache.set(fp_hash, FingerprintRecord(
134
- fingerprint_hash=fp_hash,
135
- component_type=result.component_type,
136
- confidence=result.confidence,
137
- example_segment_id=seg.segment_id
138
- ))
139
- await self.fuzzy_stage.register(
140
- fingerprint_hash=fp_hash,
141
- normalized=normalized,
142
- component_type=result.component_type,
143
- confidence=result.confidence
144
- )
165
+ llm_results = await self.llm_classifier.classify_batch(llm_items)
145
166
 
146
- classified.extend(llm_results)
167
+ # 4. Apply results to all segments in the dynamic clusters
168
+ for cluster_fps, result in zip(dynamic_clusters, llm_results):
169
+ rep_fp = cluster_fps[0]
170
+
171
+ for fp_hash in cluster_fps:
172
+ group_rep_seg = fp_to_segments[fp_hash][0]
173
+ normalized = fingerprints[group_rep_seg.segment_id][0]
174
+
175
+ # Fuzzy match penalty if not the exact representative fingerprint
176
+ confidence = result.confidence if fp_hash == rep_fp else max(0.0, result.confidence - 0.05)
177
+ stage = result.classification_stage if fp_hash == rep_fp else ClassificationStage.L2_FUZZY_CACHE
178
+
179
+ # Register in caches
180
+ await self.l1_cache.set(fp_hash, FingerprintRecord(
181
+ fingerprint_hash=fp_hash,
182
+ component_type=result.component_type,
183
+ confidence=confidence,
184
+ example_segment_id=group_rep_seg.segment_id
185
+ ))
186
+ await self.fuzzy_stage.register(
187
+ fingerprint_hash=fp_hash,
188
+ normalized=normalized,
189
+ component_type=result.component_type,
190
+ confidence=confidence
191
+ )
192
+
193
+ # Create ClassifiedSegment for all input segments sharing this fp_hash
194
+ for seg in fp_to_segments[fp_hash]:
195
+ classified.append(ClassifiedSegment(
196
+ segment_id=seg.segment_id,
197
+ page_url=seg.page_url,
198
+ page_slug=seg.page_slug,
199
+ raw_html=seg.raw_html,
200
+ text_content=seg.text_content,
201
+ position_hint=seg.position_hint,
202
+ component_type=result.component_type,
203
+ classification_stage=stage,
204
+ confidence=confidence,
205
+ fingerprint_hash=fp_hash,
206
+ cluster_id=result.cluster_id,
207
+ llm_model_used=result.llm_model_used,
208
+ llm_raw_response=result.llm_raw_response
209
+ ))
147
210
 
148
211
  # Calculate total LLM batch calls
149
212
  grouped_by_model: dict[str, int] = {}
@@ -1,7 +1,11 @@
1
1
  import asyncio
2
2
  import json
3
3
  import logging
4
+ import os
5
+ import re
6
+ import yaml
4
7
  import litellm
8
+ from litellm import Router
5
9
  from typing import Any
6
10
  from segment_classifier.models import (
7
11
  InputSegment, ClassifiedSegment, LLMClassificationRequest,
@@ -25,6 +29,20 @@ class LLMBatchClassifier:
25
29
  if settings.litellm_api_key:
26
30
  litellm.api_key = settings.litellm_api_key
27
31
 
32
+ # Initialize LiteLLM Router if config exists
33
+ self.router = None
34
+ if settings.litellm_config_path and os.path.exists(settings.litellm_config_path):
35
+ try:
36
+ with open(settings.litellm_config_path, "r") as f:
37
+ config = yaml.safe_load(f)
38
+ self.router = Router(
39
+ model_list=config.get("model_list", []),
40
+ **config.get("router_settings", {})
41
+ )
42
+ logger.info(f"Initialized LiteLLM Router with config: {settings.litellm_config_path}")
43
+ except Exception as e:
44
+ logger.error(f"Failed to load LiteLLM config from {settings.litellm_config_path}: {e}")
45
+
28
46
  def select_model(
29
47
  self,
30
48
  normalized: NormalizedSegment,
@@ -55,10 +73,10 @@ class LLMBatchClassifier:
55
73
  fingerprint_hash: str,
56
74
  ) -> LLMClassificationRequest:
57
75
  """Construct LLMClassificationRequest from segment + normalized data."""
58
- return LLMClassificationRequest(
76
+ req = LLMClassificationRequest(
59
77
  segment_id=segment.segment_id,
60
78
  fingerprint_hash=fingerprint_hash,
61
- normalized_html=normalized.skeleton,
79
+ normalized_html=normalized.normalized_html,
62
80
  position_hint=segment.position_hint,
63
81
  sibling_count=segment.sibling_count,
64
82
  url_hints=segment.url_path_segments,
@@ -66,6 +84,7 @@ class LLMBatchClassifier:
66
84
  child_tag_counts=normalized.child_tag_counts,
67
85
  text_density_ratio=normalized.text_density_ratio
68
86
  )
87
+ return req
69
88
 
70
89
  async def _call_litellm(
71
90
  self,
@@ -93,7 +112,7 @@ Available component types:
93
112
  ]
94
113
 
95
114
  Rules:
96
- - Use normalized_html structure only, ignore content values
115
+ - Use the provided raw HTML in normalized_html to understand the component purpose and content
97
116
  - sibling_count >= 3 strongly suggests a collection item
98
117
  - position_hint=top/bottom suggests layout components
99
118
  - url_hints provide page context
@@ -108,29 +127,55 @@ Rules:
108
127
  ]
109
128
 
110
129
  try:
111
- response = await litellm.acompletion(
112
- model=model,
113
- messages=messages,
114
- timeout=self.settings.litellm_timeout_seconds,
115
- )
130
+ if self.router:
131
+ response = await self.router.acompletion(
132
+ model=model,
133
+ messages=messages,
134
+ timeout=self.settings.litellm_timeout_seconds,
135
+ )
136
+ else:
137
+ response = await litellm.acompletion(
138
+ model=model,
139
+ messages=messages,
140
+ timeout=self.settings.litellm_timeout_seconds,
141
+ )
116
142
 
117
143
  # Record usage
118
144
  self._model_usage[model] = self._model_usage.get(model, 0) + 1
119
145
 
120
- raw_response = response.choices[0].message.content
121
- # Strip markdown
122
- raw_response = raw_response.strip()
123
- if raw_response.startswith("```json"):
124
- raw_response = raw_response[7:]
125
- elif raw_response.startswith("```"):
126
- raw_response = raw_response[3:]
127
- if raw_response.endswith("```"):
128
- raw_response = raw_response[:-3]
146
+ raw_response = response.choices[0].message.content.strip()
147
+
148
+ # Robust JSON extraction
149
+ json_str = raw_response
150
+ if "```" in json_str:
151
+ # Try to extract from markdown blocks
152
+ blocks = re.findall(r'```(?:json)?\s*(.*?)\s*```', json_str, re.DOTALL)
153
+ if blocks:
154
+ json_str = blocks[0]
155
+
156
+ # If still not parsing, try to find the first [ and last ]
157
+ try:
158
+ parsed = json.loads(json_str)
159
+ except json.JSONDecodeError:
160
+ start = json_str.find('[')
161
+ end = json_str.rfind(']')
162
+ if start != -1 and end != -1:
163
+ try:
164
+ parsed = json.loads(json_str[start:end+1])
165
+ except:
166
+ raise ValueError(f"Could not parse LLM response as JSON: {raw_response[:200]}...")
167
+ else:
168
+ raise ValueError(f"No JSON array found in LLM response: {raw_response[:200]}...")
129
169
 
130
- parsed = json.loads(raw_response.strip())
170
+ if not isinstance(parsed, list):
171
+ raise ValueError(f"LLM response is not a JSON array: {type(parsed)}")
131
172
 
132
173
  results = []
133
174
  for item in parsed:
175
+ if not isinstance(item, dict):
176
+ logger.warning(f"Skipping non-dict item in LLM response: {item}")
177
+ continue
178
+
134
179
  try:
135
180
  results.append(LLMClassificationResult.model_validate(item))
136
181
  except Exception as e:
@@ -139,7 +184,7 @@ Rules:
139
184
  segment_id=item.get("segment_id", ""),
140
185
  component_type=ComponentType.UNKNOWN,
141
186
  confidence=0.0,
142
- reasoning=f"Parse error: {e}"
187
+ reasoning=f"Validation error: {e}"
143
188
  ))
144
189
 
145
190
  # Ensure all segments are accounted for
@@ -16,10 +16,11 @@ STRUCTURAL_CLASS_PATTERN = re.compile(
16
16
  PRESENTATIONAL_CLASS_PATTERN = re.compile(
17
17
  r'\b(mt|mb|ml|mr|mx|my|pt|pb|pl|pr|px|py|w-|h-|text-|bg-|'
18
18
  r'border|rounded|shadow|flex|grid-cols|gap|p-|m-|font-|'
19
- r'color|opacity|z-|hidden|block|inline)\b'
19
+ r'color|opacity|z-|hidden|block|inline)\b',
20
+ re.IGNORECASE
20
21
  )
21
22
 
22
- STRUCTURAL_ATTRS = {"role", "type", "aria-label", "aria-role", "data-component", "data-type"}
23
+ STRUCTURAL_ATTRS = {"role", "type", "aria-label", "aria-role", "data-component", "data-type", "href", "src", "alt", "title", "placeholder"}
23
24
 
24
25
 
25
26
  @dataclass
@@ -32,6 +33,7 @@ class NormalizedSegment:
32
33
  root_tag: str
33
34
  text_density_ratio: float
34
35
  unique_tag_count: int
36
+ normalized_html: str = ""
35
37
 
36
38
  def fingerprint_hash(self) -> str:
37
39
  payload = {
@@ -46,6 +48,16 @@ class NormalizedSegment:
46
48
  json.dumps(payload, sort_keys=True).encode()
47
49
  ).hexdigest()
48
50
 
51
+ def to_normalized_html(self, max_depth: int = 8) -> str:
52
+ """
53
+ Generates a clean, structural HTML string for LLM consumption.
54
+ Strips text and keeps only structural tags/attributes.
55
+ """
56
+ # This is a bit tricky because the dataclass doesn't store the full tree.
57
+ # But we can reconstruct a representative HTML from the skeleton or
58
+ # better yet, modify the normalizer to produce this during the initial walk.
59
+ return self.skeleton # Fallback for now, we'll improve the producer.
60
+
49
61
 
50
62
  def normalize_segment(html: str, text_content: str) -> NormalizedSegment:
51
63
  """
@@ -54,9 +66,10 @@ def normalize_segment(html: str, text_content: str) -> NormalizedSegment:
54
66
  soup = BeautifulSoup(html, "html.parser")
55
67
  root = soup.find()
56
68
  if not root or not isinstance(root, Tag):
57
- return NormalizedSegment("", "", [], {}, 0, "unknown", 0.0, 0)
69
+ return NormalizedSegment("", "", [], {}, 0, "unknown", 0.0, 0, "")
58
70
 
59
71
  skeleton = _extract_skeleton(root)
72
+ normalized_html = _generate_normalized_html(root)
60
73
  attrs_fp = _extract_attrs_fingerprint(root)
61
74
  class_tokens = _extract_class_tokens(root)
62
75
  child_counts = _count_tags(root)
@@ -73,6 +86,7 @@ def normalize_segment(html: str, text_content: str) -> NormalizedSegment:
73
86
  root_tag=root.name,
74
87
  text_density_ratio=round(text_ratio, 4),
75
88
  unique_tag_count=unique_tags,
89
+ normalized_html=normalized_html,
76
90
  )
77
91
 
78
92
 
@@ -92,6 +106,54 @@ def _extract_skeleton(tag: Tag, depth: int = 0, max_depth: int = 8) -> str:
92
106
  return f"{tag.name}>" + "+".join(child_skeletons)
93
107
 
94
108
 
109
+ def _generate_normalized_html(tag: Tag, depth: int = 0, max_depth: int = 10) -> str:
110
+ """Recursive tag-only HTML with structural classes/attributes."""
111
+ if depth >= max_depth:
112
+ return ""
113
+
114
+ tag_name = tag.name
115
+ if tag_name in {"script", "style", "meta", "link", "noscript", "svg", "path", "circle", "rect", "line", "polyline", "polygon", "ellipse"}:
116
+ # Simplification: skip heavy SVG/non-visible tags
117
+ return ""
118
+
119
+ attrs = []
120
+
121
+ # Keep structural attributes
122
+ for attr in STRUCTURAL_ATTRS:
123
+ val = tag.get(attr)
124
+ if val:
125
+ if isinstance(val, list):
126
+ val = " ".join(val)
127
+ attrs.append(f'{attr}="{val}"')
128
+
129
+ # Keep structural classes
130
+ classes = tag.get("class", [])
131
+ if isinstance(classes, str):
132
+ classes = [classes]
133
+ relevant_classes = [c for c in classes if STRUCTURAL_CLASS_PATTERN.search(c) or PRESENTATIONAL_CLASS_PATTERN.search(c)]
134
+ if relevant_classes:
135
+ attrs.append(f'class="{" ".join(relevant_classes)}"')
136
+
137
+ attr_str = " " + " ".join(attrs) if attrs else ""
138
+
139
+ children_html = ""
140
+ for child in tag.children:
141
+ if isinstance(child, Tag):
142
+ children_html += _generate_normalized_html(child, depth + 1, max_depth)
143
+ elif isinstance(child, str):
144
+ text = child.strip()
145
+ if text:
146
+ if len(text) > 200:
147
+ text = text[:197] + "..."
148
+ children_html += text
149
+
150
+ if not children_html:
151
+ # Self-closing for empty tags is fine, or keep them open if preferred
152
+ return f"<{tag_name}{attr_str}></{tag_name}>"
153
+
154
+ return f"<{tag_name}{attr_str}>{children_html}</{tag_name}>"
155
+
156
+
95
157
  def _extract_attrs_fingerprint(tag: Tag) -> str:
96
158
  """Walk all tags, keep only STRUCTURAL_ATTRS values and href/src presence booleans."""
97
159
  parts = []