rnsr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnsr/__init__.py +118 -0
- rnsr/__main__.py +242 -0
- rnsr/agent/__init__.py +218 -0
- rnsr/agent/cross_doc_navigator.py +767 -0
- rnsr/agent/graph.py +1557 -0
- rnsr/agent/llm_cache.py +575 -0
- rnsr/agent/navigator_api.py +497 -0
- rnsr/agent/provenance.py +772 -0
- rnsr/agent/query_clarifier.py +617 -0
- rnsr/agent/reasoning_memory.py +736 -0
- rnsr/agent/repl_env.py +709 -0
- rnsr/agent/rlm_navigator.py +2108 -0
- rnsr/agent/self_reflection.py +602 -0
- rnsr/agent/variable_store.py +308 -0
- rnsr/benchmarks/__init__.py +118 -0
- rnsr/benchmarks/comprehensive_benchmark.py +733 -0
- rnsr/benchmarks/evaluation_suite.py +1210 -0
- rnsr/benchmarks/finance_bench.py +147 -0
- rnsr/benchmarks/pdf_merger.py +178 -0
- rnsr/benchmarks/performance.py +321 -0
- rnsr/benchmarks/quality.py +321 -0
- rnsr/benchmarks/runner.py +298 -0
- rnsr/benchmarks/standard_benchmarks.py +995 -0
- rnsr/client.py +560 -0
- rnsr/document_store.py +394 -0
- rnsr/exceptions.py +74 -0
- rnsr/extraction/__init__.py +172 -0
- rnsr/extraction/candidate_extractor.py +357 -0
- rnsr/extraction/entity_extractor.py +581 -0
- rnsr/extraction/entity_linker.py +825 -0
- rnsr/extraction/grounded_extractor.py +722 -0
- rnsr/extraction/learned_types.py +599 -0
- rnsr/extraction/models.py +232 -0
- rnsr/extraction/relationship_extractor.py +600 -0
- rnsr/extraction/relationship_patterns.py +511 -0
- rnsr/extraction/relationship_validator.py +392 -0
- rnsr/extraction/rlm_extractor.py +589 -0
- rnsr/extraction/rlm_unified_extractor.py +990 -0
- rnsr/extraction/tot_validator.py +610 -0
- rnsr/extraction/unified_extractor.py +342 -0
- rnsr/indexing/__init__.py +60 -0
- rnsr/indexing/knowledge_graph.py +1128 -0
- rnsr/indexing/kv_store.py +313 -0
- rnsr/indexing/persistence.py +323 -0
- rnsr/indexing/semantic_retriever.py +237 -0
- rnsr/indexing/semantic_search.py +320 -0
- rnsr/indexing/skeleton_index.py +395 -0
- rnsr/ingestion/__init__.py +161 -0
- rnsr/ingestion/chart_parser.py +569 -0
- rnsr/ingestion/document_boundary.py +662 -0
- rnsr/ingestion/font_histogram.py +334 -0
- rnsr/ingestion/header_classifier.py +595 -0
- rnsr/ingestion/hierarchical_cluster.py +515 -0
- rnsr/ingestion/layout_detector.py +356 -0
- rnsr/ingestion/layout_model.py +379 -0
- rnsr/ingestion/ocr_fallback.py +177 -0
- rnsr/ingestion/pipeline.py +936 -0
- rnsr/ingestion/semantic_fallback.py +417 -0
- rnsr/ingestion/table_parser.py +799 -0
- rnsr/ingestion/text_builder.py +460 -0
- rnsr/ingestion/tree_builder.py +402 -0
- rnsr/ingestion/vision_retrieval.py +965 -0
- rnsr/ingestion/xy_cut.py +555 -0
- rnsr/llm.py +733 -0
- rnsr/models.py +167 -0
- rnsr/py.typed +2 -0
- rnsr-0.1.0.dist-info/METADATA +592 -0
- rnsr-0.1.0.dist-info/RECORD +72 -0
- rnsr-0.1.0.dist-info/WHEEL +5 -0
- rnsr-0.1.0.dist-info/entry_points.txt +2 -0
- rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
- rnsr-0.1.0.dist-info/top_level.txt +1 -0
rnsr/extraction/grounded_extractor.py
@@ -0,0 +1,722 @@
"""
RNSR Grounded Entity Extractor

Implements the RLM pattern for entity extraction:
1. Pre-extract candidates using regex/patterns (CODE FIRST)
2. LLM classifies and validates candidates (LLM SECOND)
3. Recursive refinement if needed

This approach PREVENTS HALLUCINATION because:
- Every entity is tied to an exact text span
- LLM classifies existing text, doesn't generate entities
- Pattern matching provides grounded candidates
- LLM's job is validation, not invention

Validation Modes:
- SIMPLE: Basic LLM classification (faster, single call)
- TOT: Tree of Thoughts validation with probabilities and navigation (more accurate)

Inspired by the RLM paper's insight: use code to filter/extract
before sending to LLM for reasoning.
"""
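For orientation, here is a minimal sketch of the "CODE FIRST" step the docstring describes. The real `CandidateExtractor` lives in rnsr/extraction/candidate_extractor.py and is not part of this diff; the patterns, class, and field names below are illustrative assumptions, not the package's actual implementation.

```python
# Illustrative only: a stripped-down version of grounded pre-extraction.
# The real CandidateExtractor is richer; these patterns are assumptions.
import re
from dataclasses import dataclass

@dataclass
class Candidate:
    text: str            # exact span as it appears in the document
    start: int           # character offset into the section content
    end: int
    candidate_type: str  # pattern-based type hint

PATTERNS = {
    "person_title": re.compile(r"\b(?:Mr|Ms|Mrs|Dr)\.\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*"),
    "monetary": re.compile(r"\$\d[\d,]*(?:\.\d{2})?"),
}

def extract_candidates(content: str) -> list[Candidate]:
    """Every candidate carries its exact span, so nothing is invented."""
    found = []
    for name, pattern in PATTERNS.items():
        for m in pattern.finditer(content):
            found.append(Candidate(m.group(), m.start(), m.end(), name))
    return found

print(extract_candidates("Mr. John Smith paid $1,200.00 on the closing date."))
```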
from __future__ import annotations

import json
import re
import time
from enum import Enum
from typing import Any, TYPE_CHECKING

import structlog

from rnsr.extraction.models import (
    Entity,
    EntityType,
    ExtractionResult,
    Mention,
)
from rnsr.extraction.candidate_extractor import (
    CandidateExtractor,
    EntityCandidate,
)
from rnsr.extraction.learned_types import (
    get_learned_type_registry,
)
from rnsr.llm import get_llm

if TYPE_CHECKING:
    from rnsr.models import DocumentTree

logger = structlog.get_logger(__name__)


class ValidationMode(str, Enum):
    """Validation mode for candidate entities."""

    SIMPLE = "simple"  # Basic LLM classification (faster)
    TOT = "tot"  # Tree of Thoughts with navigation (more accurate)


# LLM prompt for CLASSIFYING pre-extracted candidates (not generating)
CLASSIFICATION_PROMPT = """You are classifying entity candidates that have already been extracted from a document.

Your job is to VALIDATE and CLASSIFY each candidate - NOT to generate new entities.
These candidates were extracted by pattern matching and are grounded in the actual text.

Document Section:
---
{content}
---

Pre-extracted candidates to classify:
{candidates_json}

For each candidate, provide:
1. valid: true if this is a real entity worth tracking, false if it's noise
2. type: The entity type (PERSON, ORGANIZATION, DATE, EVENT, LEGAL_CONCEPT, LOCATION, REFERENCE, MONETARY, DOCUMENT, or other descriptive type)
3. canonical_name: Normalized/cleaned name (e.g., "Mr. John Smith" → "John Smith")
4. role: Any role or relationship mentioned (e.g., "defendant", "CEO")

Return JSON array:
```json
[
  {{
    "candidate_id": 0,
    "valid": true,
    "type": "PERSON",
    "canonical_name": "John Smith",
    "role": "defendant"
  }},
  {{
    "candidate_id": 1,
    "valid": false,
    "reason": "Generic reference, not a specific entity"
  }}
]
```

Rules:
- ONLY classify the candidates provided - do not add new entities
- Set valid=false for generic terms, partial matches, or noise
- Use the exact text span provided - don't modify the match boundaries
- Be conservative - when uncertain, set valid=false
"""

# Prompt for finding entities the patterns might have missed
SUPPLEMENTARY_PROMPT = """Review this text for important entities that might have been missed by pattern matching.

Text:
---
{content}
---

Already extracted: {existing_entities}

Are there any CLEARLY IDENTIFIABLE entities that were missed?
Only list entities that:
1. Are explicitly named in the text (not implied)
2. Are significant (not passing mentions)
3. Have a clear type

If there are missed entities, return:
```json
[
  {{
    "text": "exact text as it appears",
    "type": "ENTITY_TYPE",
    "canonical_name": "Normalized Name",
    "reason": "Why this is important"
  }}
]
```

If nothing important was missed, return: []
"""


class GroundedEntityExtractor:
    """
    Entity extractor following the RLM pattern:
    CODE FIRST (pattern extraction) → LLM SECOND (classification).

    This prevents hallucination by grounding entities in actual text.
    """

    def __init__(
        self,
        llm: Any | None = None,
        candidate_extractor: CandidateExtractor | None = None,
        min_content_length: int = 50,
        max_candidates_per_batch: int = 30,
        enable_supplementary_extraction: bool = True,
        enable_type_learning: bool = True,
        validation_mode: ValidationMode | str = ValidationMode.SIMPLE,
        tot_selection_threshold: float = 0.6,
        tot_enable_navigation: bool = True,
    ):
        """
        Initialize the grounded extractor.

        Args:
            llm: LLM instance. If None, uses get_llm().
            candidate_extractor: Pre-extraction engine.
            min_content_length: Minimum content length to process.
            max_candidates_per_batch: Max candidates per LLM call.
            enable_supplementary_extraction: Check for missed entities.
            enable_type_learning: Learn new entity types.
            validation_mode: SIMPLE (faster) or TOT (more accurate with navigation).
            tot_selection_threshold: Probability threshold for ToT mode.
            tot_enable_navigation: Navigate tree for uncertain candidates in ToT mode.
        """
        self.llm = llm
        self.candidate_extractor = candidate_extractor or CandidateExtractor()
        self.min_content_length = min_content_length
        self.max_candidates_per_batch = max_candidates_per_batch
        self.enable_supplementary_extraction = enable_supplementary_extraction
        self.enable_type_learning = enable_type_learning

        # Validation mode
        if isinstance(validation_mode, str):
            validation_mode = ValidationMode(validation_mode.lower())
        self.validation_mode = validation_mode
        self.tot_selection_threshold = tot_selection_threshold
        self.tot_enable_navigation = tot_enable_navigation

        # Lazy init for ToT validator
        self._tot_validator = None

        # Lazy LLM init
        self._llm_initialized = False

        # Type registry for learning
        self._type_registry = get_learned_type_registry() if enable_type_learning else None

        # Cache
        self._cache: dict[str, list[Entity]] = {}
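A construction sketch, assuming an LLM provider is configured via rnsr.llm.get_llm(); both enum and string forms of validation_mode are accepted, per the normalization above.

```python
from rnsr.extraction.grounded_extractor import GroundedEntityExtractor, ValidationMode

# Fast path: one classification call per batch of candidates.
simple = GroundedEntityExtractor(validation_mode=ValidationMode.SIMPLE)

# Accurate path: Tree of Thoughts validation; string modes work too.
tot = GroundedEntityExtractor(
    validation_mode="tot",            # normalized via ValidationMode(value.lower())
    tot_selection_threshold=0.6,      # probability cutoff for keeping a candidate
    tot_enable_navigation=True,       # follow the tree for uncertain candidates
)
```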
    def _get_llm(self) -> Any:
        """Get or initialize LLM."""
        if self.llm is None and not self._llm_initialized:
            self.llm = get_llm()
            self._llm_initialized = True
        return self.llm

    def _get_tot_validator(self) -> "TotEntityValidator":
        """Get or initialize ToT validator."""
        if self._tot_validator is None:
            from rnsr.extraction.tot_validator import TotEntityValidator

            self._tot_validator = TotEntityValidator(
                llm=self._get_llm(),
                selection_threshold=self.tot_selection_threshold,
                enable_navigation=self.tot_enable_navigation,
                max_candidates_per_batch=self.max_candidates_per_batch,
            )
        return self._tot_validator

    def extract_from_node(
        self,
        node_id: str,
        doc_id: str,
        header: str,
        content: str,
        page_num: int | None = None,
        document_tree: "DocumentTree | None" = None,
    ) -> ExtractionResult:
        """
        Extract entities using the grounded approach.

        Flow:
        1. Pattern-extract candidates (grounded)
        2. Validate with LLM (SIMPLE mode) or ToT (TOT mode)
        3. Optionally check for missed entities
        4. Return validated entities

        Args:
            node_id: Section node ID.
            doc_id: Document ID.
            header: Section header.
            content: Section content.
            page_num: Page number.
            document_tree: Optional tree for ToT navigation.

        Returns:
            ExtractionResult with grounded entities.
        """
        start_time = time.time()
        result = ExtractionResult(
            node_id=node_id,
            doc_id=doc_id,
            extraction_method=f"grounded_{self.validation_mode.value}",
        )

        # Skip short content
        if len(content.strip()) < self.min_content_length:
            return result

        # Check cache
        cache_key = f"{doc_id}:{node_id}"
        if cache_key in self._cache:
            result.entities = self._cache[cache_key]
            return result

        # STEP 1: Extract candidates using patterns (CODE FIRST)
        candidates = self.candidate_extractor.extract_candidates(content)

        result.warnings.append(f"Pattern extraction found {len(candidates)} candidates")

        if not candidates:
            # No pattern matches - try supplementary if enabled
            if self.enable_supplementary_extraction:
                entities = self._supplementary_extraction(
                    content, node_id, doc_id, page_num, []
                )
                result.entities = entities

            result.processing_time_ms = (time.time() - start_time) * 1000
            return result

        # STEP 2: Validate candidates (LLM SECOND)
        if self.validation_mode == ValidationMode.TOT:
            # Use Tree of Thoughts validation (more accurate, can navigate)
            entities = self._validate_with_tot(
                candidates=candidates,
                header=header,
                content=content,
                node_id=node_id,
                doc_id=doc_id,
                page_num=page_num,
                document_tree=document_tree,
            )
        else:
            # Use simple classification (faster)
            entities = self._classify_candidates(
                candidates=candidates,
                content=content,
                node_id=node_id,
                doc_id=doc_id,
                page_num=page_num,
            )

        # STEP 3: Check for missed entities (optional)
        if self.enable_supplementary_extraction:
            existing_names = [e.canonical_name for e in entities]
            supplementary = self._supplementary_extraction(
                content, node_id, doc_id, page_num, existing_names
            )
            entities.extend(supplementary)

        result.entities = entities
        result.processing_time_ms = (time.time() - start_time) * 1000

        # Cache
        self._cache[cache_key] = entities

        logger.info(
            "grounded_extraction_complete",
            node_id=node_id,
            candidates=len(candidates),
            entities=len(entities),
            validation_mode=self.validation_mode.value,
            time_ms=result.processing_time_ms,
        )

        return result
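A hypothetical end-to-end call, under the same assumption that get_llm() resolves a provider (without one, the simple path falls back to pattern-based types via _candidates_to_entities below); the IDs and content are illustrative.

```python
extractor = GroundedEntityExtractor()

result = extractor.extract_from_node(
    node_id="sec-4.2",          # illustrative IDs
    doc_id="contract-001",
    header="4.2 Parties",
    content="This Agreement is made between Acme Corp and Mr. John Smith ...",
    page_num=12,
)

for entity in result.entities:
    m = entity.mentions[0]
    # Every mention carries an exact character span back into `content`.
    print(entity.type, entity.canonical_name, (m.span_start, m.span_end))
```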
    def _validate_with_tot(
        self,
        candidates: list[EntityCandidate],
        header: str,
        content: str,
        node_id: str,
        doc_id: str,
        page_num: int | None,
        document_tree: "DocumentTree | None",
    ) -> list[Entity]:
        """
        Validate candidates using the Tree of Thoughts pattern.

        This uses the same ToT approach as document navigation:
        - Evaluate each candidate with probability + reasoning
        - Navigate to related sections for uncertain candidates
        - More accurate than simple classification
        """
        tot_validator = self._get_tot_validator()

        # Run ToT validation
        validation_result = tot_validator.validate_candidates(
            candidates=candidates,
            section_header=header,
            section_content=content,
            document_tree=document_tree,
            node_id=node_id,
        )

        # Convert to entities
        entities = tot_validator.candidates_to_entities(
            candidates=candidates,
            validation_result=validation_result,
            node_id=node_id,
            doc_id=doc_id,
            page_num=page_num,
        )

        # Learn from OTHER types
        if self._type_registry:
            for entity in entities:
                if entity.type == EntityType.OTHER:
                    original_type = entity.metadata.get("original_type", "unknown")
                    context = entity.mentions[0].context if entity.mentions else ""
                    self._type_registry.record_type(
                        type_name=original_type,
                        context=context,
                        entity_name=entity.canonical_name,
                    )

        logger.info(
            "tot_validation_complete",
            candidates=len(candidates),
            validated=len(entities),
            high_confidence=validation_result.high_confidence_count,
            low_confidence=validation_result.low_confidence_count,
        )

        return entities
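TotEntityValidator itself is defined in rnsr/extraction/tot_validator.py (+610 in this release) and is not shown here. From this call site one can infer its minimum surface; the stub below is a hypothetical test double built on that inference only, not the real validator's interface.

```python
# Hypothetical stub inferred from the call site above; not the real validator.
from dataclasses import dataclass

@dataclass
class StubValidationResult:
    high_confidence_count: int = 0
    low_confidence_count: int = 0

class StubTotValidator:
    def validate_candidates(self, candidates, section_header,
                            section_content, document_tree, node_id):
        # Pretend every candidate passed with high confidence.
        return StubValidationResult(high_confidence_count=len(candidates))

    def candidates_to_entities(self, candidates, validation_result,
                               node_id, doc_id, page_num):
        return []  # a real double would build Entity objects here

extractor = GroundedEntityExtractor(validation_mode="tot")
extractor._tot_validator = StubTotValidator()  # pre-empt the lazy init
```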
    def _classify_candidates(
        self,
        candidates: list[EntityCandidate],
        content: str,
        node_id: str,
        doc_id: str,
        page_num: int | None,
    ) -> list[Entity]:
        """
        Use the LLM to classify pre-extracted candidates.

        The LLM's job is VALIDATION and CLASSIFICATION,
        not generation of new entities.
        """
        llm = self._get_llm()
        if llm is None:
            # No LLM - return candidates as-is with pattern-based types
            return self._candidates_to_entities(
                candidates, node_id, doc_id, page_num
            )

        entities = []

        # Process in batches
        for i in range(0, len(candidates), self.max_candidates_per_batch):
            batch = candidates[i:i + self.max_candidates_per_batch]
            batch_entities = self._classify_batch(
                batch, content, node_id, doc_id, page_num
            )
            entities.extend(batch_entities)

        return entities

    def _classify_batch(
        self,
        candidates: list[EntityCandidate],
        content: str,
        node_id: str,
        doc_id: str,
        page_num: int | None,
    ) -> list[Entity]:
        """Classify a batch of candidates with the LLM."""
        # Format candidates for the prompt
        candidates_json = json.dumps([
            {
                "id": idx,
                "text": c.text,
                "type_hint": c.candidate_type,
                "context": c.context[:150],
            }
            for idx, c in enumerate(candidates)
        ], indent=2)

        prompt = CLASSIFICATION_PROMPT.format(
            content=content[:3000],  # Limit content size
            candidates_json=candidates_json,
        )

        try:
            response = self.llm.complete(prompt)
            response_text = str(response) if not isinstance(response, str) else response

            # Parse classifications
            classifications = self._parse_classification_response(response_text)

            # Convert to entities
            entities = []
            for classification in classifications:
                candidate_id = classification.get("candidate_id")

                if candidate_id is None or candidate_id >= len(candidates):
                    continue

                if not classification.get("valid", False):
                    continue

                candidate = candidates[candidate_id]
                entity = self._create_entity_from_classification(
                    candidate=candidate,
                    classification=classification,
                    node_id=node_id,
                    doc_id=doc_id,
                    page_num=page_num,
                )

                if entity:
                    entities.append(entity)

            return entities

        except Exception as e:
            logger.warning("classification_failed", error=str(e))
            # Fallback: return candidates as-is
            return self._candidates_to_entities(
                candidates, node_id, doc_id, page_num
            )

    def _parse_classification_response(
        self,
        response_text: str,
    ) -> list[dict[str, Any]]:
        """Parse the LLM classification response."""
        # Extract JSON: prefer a fenced block, fall back to a bare array
        json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response_text)
        if json_match:
            json_str = json_match.group(1)
        else:
            json_match = re.search(r'\[[\s\S]*\]', response_text)
            json_str = json_match.group(0) if json_match else "[]"

        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            return []
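The two-stage recovery above (fenced block first, bare array as fallback) is easy to sanity-check in isolation. A standalone sketch, with the fence string assembled programmatically so it survives this document's formatting:

```python
import json
import re

FENCE = "`" * 3  # three backticks, built up to avoid nesting issues here

def parse(response_text: str) -> list:
    # Stage 1: fenced JSON block; Stage 2: first bare [...] array
    m = re.search(FENCE + r"(?:json)?\s*([\s\S]*?)\s*" + FENCE, response_text)
    if m:
        json_str = m.group(1)
    else:
        m = re.search(r"\[[\s\S]*\]", response_text)
        json_str = m.group(0) if m else "[]"
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        return []

fenced = f'Sure:\n{FENCE}json\n[{{"candidate_id": 0, "valid": true}}]\n{FENCE}'
print(parse(fenced))                                   # [{'candidate_id': 0, 'valid': True}]
print(parse('[{"candidate_id": 1, "valid": false}]'))  # bare-array fallback
print(parse("no JSON at all"))                         # []
```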
    def _create_entity_from_classification(
        self,
        candidate: EntityCandidate,
        classification: dict[str, Any],
        node_id: str,
        doc_id: str,
        page_num: int | None,
    ) -> Entity | None:
        """Create an Entity from a candidate + LLM classification."""
        entity_type_str = classification.get("type", candidate.candidate_type).upper()

        # Map to EntityType
        entity_type = self._map_entity_type(entity_type_str)

        # Record learned type if OTHER
        if entity_type == EntityType.OTHER and self._type_registry:
            self._type_registry.record_type(
                type_name=entity_type_str.lower(),
                context=candidate.context,
                entity_name=candidate.text,
            )

        # Get canonical name (LLM-cleaned or original)
        canonical_name = classification.get("canonical_name", candidate.text).strip()
        if not canonical_name:
            canonical_name = candidate.text

        # Build metadata
        metadata = {
            "grounded": True,  # Flag that this is grounded in text
            "pattern": candidate.pattern_name,
            "span_start": candidate.start,
            "span_end": candidate.end,
        }

        if classification.get("role"):
            metadata["role"] = classification["role"]

        if entity_type == EntityType.OTHER:
            metadata["original_type"] = entity_type_str.lower()

        # Create mention
        mention = Mention(
            node_id=node_id,
            doc_id=doc_id,
            span_start=candidate.start,
            span_end=candidate.end,
            context=candidate.context,
            page_num=page_num,
            confidence=candidate.confidence,
        )

        return Entity(
            type=entity_type,
            canonical_name=canonical_name,
            aliases=[candidate.text] if candidate.text != canonical_name else [],
            mentions=[mention],
            metadata=metadata,
            source_doc_id=doc_id,
        )

    def _map_entity_type(self, type_str: str) -> EntityType:
        """Map a type string to the EntityType enum."""
        type_str = type_str.upper()

        mapping = {
            "PERSON": EntityType.PERSON,
            "PEOPLE": EntityType.PERSON,
            "INDIVIDUAL": EntityType.PERSON,
            "ORGANIZATION": EntityType.ORGANIZATION,
            "ORG": EntityType.ORGANIZATION,
            "COMPANY": EntityType.ORGANIZATION,
            "COURT": EntityType.ORGANIZATION,
            "DATE": EntityType.DATE,
            "TIME": EntityType.DATE,
            "EVENT": EntityType.EVENT,
            "LEGAL_CONCEPT": EntityType.LEGAL_CONCEPT,
            "LEGAL": EntityType.LEGAL_CONCEPT,
            "LOCATION": EntityType.LOCATION,
            "PLACE": EntityType.LOCATION,
            "ADDRESS": EntityType.LOCATION,
            "REFERENCE": EntityType.REFERENCE,
            "CITATION": EntityType.REFERENCE,
            "MONETARY": EntityType.MONETARY,
            "MONEY": EntityType.MONETARY,
            "DOCUMENT": EntityType.DOCUMENT,
        }

        # Try the enum value directly first, then the alias table
        try:
            return EntityType(type_str.lower())
        except ValueError:
            return mapping.get(type_str, EntityType.OTHER)
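The lookup order matters: the enum constructor runs first, so the alias table only fires for names that are not already EntityType values. Assuming the enum's values are lowercase strings (which `EntityType(type_str.lower())` implies), behavior looks like:

```python
extractor = GroundedEntityExtractor()

extractor._map_entity_type("person")   # EntityType.PERSON  (direct enum hit)
extractor._map_entity_type("COURT")    # EntityType.ORGANIZATION  (alias table)
extractor._map_entity_type("MONEY")    # EntityType.MONETARY  (alias table)
extractor._map_entity_type("WIDGET")   # EntityType.OTHER  (unknown; the learned-type
                                       # registry may record it upstream)
```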
    def _candidates_to_entities(
        self,
        candidates: list[EntityCandidate],
        node_id: str,
        doc_id: str,
        page_num: int | None,
    ) -> list[Entity]:
        """Convert candidates directly to entities (no LLM)."""
        entities = []

        for candidate in candidates:
            entity_type = self._map_entity_type(candidate.candidate_type)

            mention = Mention(
                node_id=node_id,
                doc_id=doc_id,
                span_start=candidate.start,
                span_end=candidate.end,
                context=candidate.context,
                page_num=page_num,
                confidence=candidate.confidence,
            )

            entity = Entity(
                type=entity_type,
                canonical_name=candidate.text,
                mentions=[mention],
                metadata={
                    "grounded": True,
                    "pattern": candidate.pattern_name,
                    "llm_validated": False,
                },
                source_doc_id=doc_id,
            )
            entities.append(entity)

        return entities

    def _supplementary_extraction(
        self,
        content: str,
        node_id: str,
        doc_id: str,
        page_num: int | None,
        existing_names: list[str],
    ) -> list[Entity]:
        """
        Check for entities that the patterns might have missed.

        This is a safety net, but the LLM is instructed to be
        conservative and only add clearly identifiable entities.
        """
        llm = self._get_llm()
        if llm is None:
            return []

        prompt = SUPPLEMENTARY_PROMPT.format(
            content=content[:2000],
            existing_entities=", ".join(existing_names) if existing_names else "None",
        )

        try:
            response = llm.complete(prompt)
            response_text = str(response) if not isinstance(response, str) else response

            # Parse response
            json_match = re.search(r'\[[\s\S]*?\]', response_text)
            if not json_match:
                return []

            missed = json.loads(json_match.group())

            if not isinstance(missed, list):
                return []

            entities = []
            for item in missed:
                text = item.get("text", "").strip()
                if not text or text in existing_names:
                    continue

                entity_type = self._map_entity_type(item.get("type", "OTHER"))
                canonical = item.get("canonical_name", text)

                # Find the text in content to get its position
                match = re.search(re.escape(text), content)
                span_start = match.start() if match else None
                span_end = match.end() if match else None

                # Check "is not None" so a match at offset 0 keeps its context
                context = (
                    content[max(0, span_start - 50):span_end + 50]
                    if span_start is not None and span_end is not None
                    else ""
                )

                mention = Mention(
                    node_id=node_id,
                    doc_id=doc_id,
                    span_start=span_start,
                    span_end=span_end,
                    context=context,
                    page_num=page_num,
                    confidence=0.6,  # Lower confidence for supplementary entities
                )

                entity = Entity(
                    type=entity_type,
                    canonical_name=canonical,
                    aliases=[text] if text != canonical else [],
                    mentions=[mention],
                    metadata={
                        "grounded": span_start is not None,
                        "supplementary": True,
                        "reason": item.get("reason", ""),
                    },
                    source_doc_id=doc_id,
                )
                entities.append(entity)

            if entities:
                logger.debug(
                    "supplementary_entities_found",
                    count=len(entities),
                    names=[e.canonical_name for e in entities],
                )

            return entities

        except Exception as e:
            logger.debug("supplementary_extraction_failed", error=str(e))
            return []
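The safety net stays honest through re-grounding: a suggestion only counts as grounded if its exact text is found in the section; otherwise the mention is kept with no span and `metadata["grounded"]` is False. A minimal illustration:

```python
import re

content = "The motion was filed by Acme Corp on 12 March 2021."
suggested = "Acme Corp"  # text the LLM claims appears verbatim

match = re.search(re.escape(suggested), content)
if match:
    span = (match.start(), match.end())  # (24, 33): grounded in the source
else:
    span = None                          # kept, but flagged as ungrounded
```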
    def clear_cache(self) -> None:
        """Clear the extraction cache."""
        self._cache.clear()