corp-extractor 0.5.0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.3.dist-info}/METADATA +228 -30
- corp_extractor-0.9.3.dist-info/RECORD +79 -0
- statement_extractor/__init__.py +1 -1
- statement_extractor/cli.py +2030 -24
- statement_extractor/data/statement_taxonomy.json +6949 -1159
- statement_extractor/database/__init__.py +52 -0
- statement_extractor/database/embeddings.py +186 -0
- statement_extractor/database/hub.py +428 -0
- statement_extractor/database/importers/__init__.py +32 -0
- statement_extractor/database/importers/companies_house.py +559 -0
- statement_extractor/database/importers/companies_house_officers.py +431 -0
- statement_extractor/database/importers/gleif.py +561 -0
- statement_extractor/database/importers/sec_edgar.py +392 -0
- statement_extractor/database/importers/sec_form4.py +512 -0
- statement_extractor/database/importers/wikidata.py +1120 -0
- statement_extractor/database/importers/wikidata_dump.py +1951 -0
- statement_extractor/database/importers/wikidata_people.py +1130 -0
- statement_extractor/database/models.py +254 -0
- statement_extractor/database/resolver.py +245 -0
- statement_extractor/database/store.py +3034 -0
- statement_extractor/document/__init__.py +62 -0
- statement_extractor/document/chunker.py +410 -0
- statement_extractor/document/context.py +171 -0
- statement_extractor/document/deduplicator.py +171 -0
- statement_extractor/document/html_extractor.py +246 -0
- statement_extractor/document/loader.py +303 -0
- statement_extractor/document/pipeline.py +388 -0
- statement_extractor/document/summarizer.py +195 -0
- statement_extractor/extractor.py +1 -1
- statement_extractor/models/__init__.py +19 -3
- statement_extractor/models/canonical.py +44 -1
- statement_extractor/models/document.py +308 -0
- statement_extractor/models/labels.py +47 -18
- statement_extractor/models/qualifiers.py +51 -3
- statement_extractor/models/statement.py +39 -15
- statement_extractor/models.py +1 -1
- statement_extractor/pipeline/config.py +6 -11
- statement_extractor/pipeline/context.py +5 -5
- statement_extractor/pipeline/orchestrator.py +90 -121
- statement_extractor/pipeline/registry.py +52 -46
- statement_extractor/plugins/__init__.py +20 -8
- statement_extractor/plugins/base.py +348 -78
- statement_extractor/plugins/extractors/gliner2.py +38 -28
- statement_extractor/plugins/labelers/taxonomy.py +18 -5
- statement_extractor/plugins/labelers/taxonomy_embedding.py +17 -6
- statement_extractor/plugins/pdf/__init__.py +10 -0
- statement_extractor/plugins/pdf/pypdf.py +291 -0
- statement_extractor/plugins/qualifiers/__init__.py +11 -0
- statement_extractor/plugins/qualifiers/companies_house.py +14 -3
- statement_extractor/plugins/qualifiers/embedding_company.py +422 -0
- statement_extractor/plugins/qualifiers/gleif.py +14 -3
- statement_extractor/plugins/qualifiers/person.py +588 -14
- statement_extractor/plugins/qualifiers/sec_edgar.py +14 -3
- statement_extractor/plugins/scrapers/__init__.py +10 -0
- statement_extractor/plugins/scrapers/http.py +236 -0
- statement_extractor/plugins/splitters/t5_gemma.py +176 -75
- statement_extractor/plugins/taxonomy/embedding.py +193 -46
- statement_extractor/plugins/taxonomy/mnli.py +16 -4
- statement_extractor/scoring.py +8 -8
- corp_extractor-0.5.0.dist-info/RECORD +0 -55
- statement_extractor/plugins/canonicalizers/__init__.py +0 -17
- statement_extractor/plugins/canonicalizers/base.py +0 -9
- statement_extractor/plugins/canonicalizers/location.py +0 -219
- statement_extractor/plugins/canonicalizers/organization.py +0 -230
- statement_extractor/plugins/canonicalizers/person.py +0 -242
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.3.dist-info}/WHEEL +0 -0
- {corp_extractor-0.5.0.dist-info → corp_extractor-0.9.3.dist-info}/entry_points.txt +0 -0
statement_extractor/document/pipeline.py ADDED

@@ -0,0 +1,388 @@
"""
DocumentPipeline - Orchestrates document-level extraction with chunking and citations.

Wraps ExtractionPipeline to provide document-level features:
- Text chunking with page awareness
- Batch processing through pipeline stages
- Statement deduplication across chunks
- Citation generation from document metadata
"""

import logging
import time
from typing import Optional

from pydantic import BaseModel, Field

from ..models.document import ChunkingConfig, Document
from ..pipeline import ExtractionPipeline, PipelineConfig
from .chunker import DocumentChunker
from .context import DocumentContext
from .deduplicator import StatementDeduplicator
from .summarizer import DocumentSummarizer

logger = logging.getLogger(__name__)


class DocumentPipelineConfig(BaseModel):
    """Configuration for document pipeline processing."""

    chunking: ChunkingConfig = Field(
        default_factory=ChunkingConfig,
        description="Configuration for text chunking"
    )
    generate_summary: bool = Field(
        default=True,
        description="Whether to generate a document summary"
    )
    deduplicate_across_chunks: bool = Field(
        default=True,
        description="Whether to deduplicate statements across chunks"
    )
    batch_size: int = Field(
        default=10,
        ge=1,
        description="Number of items to process in each batch"
    )
    pipeline_config: Optional[PipelineConfig] = Field(
        default=None,
        description="Configuration for the underlying extraction pipeline"
    )


class DocumentPipeline:
    """
    Document-level extraction pipeline.

    Processes documents through:
    1. Summary generation (optional)
    2. Chunking with page awareness
    3. Batch extraction through all pipeline stages
    4. Deduplication across chunks
    5. Citation generation

    Example:
        >>> pipeline = DocumentPipeline()
        >>> document = Document.from_text("Long document text...", title="Report")
        >>> ctx = pipeline.process(document)
        >>> for stmt in ctx.labeled_statements:
        ...     print(f"{stmt}: {stmt.citation}")
    """

    def __init__(
        self,
        config: Optional[DocumentPipelineConfig] = None,
    ):
        """
        Initialize the document pipeline.

        Args:
            config: Document pipeline configuration
        """
        self.config = config or DocumentPipelineConfig()

        # Initialize components
        self._chunker = DocumentChunker(self.config.chunking)
        self._deduplicator = StatementDeduplicator()
        self._summarizer = DocumentSummarizer() if self.config.generate_summary else None
        self._pipeline = ExtractionPipeline(self.config.pipeline_config)

    def process(self, document: Document) -> DocumentContext:
        """
        Process a document through the pipeline.

        Args:
            document: Document to process

        Returns:
            DocumentContext with all extraction results
        """
        logger.info(f"Starting document pipeline: {document.document_id}")
        start_time = time.time()

        ctx = DocumentContext(document=document)

        try:
            # Step 1: Generate summary (if enabled)
            if self.config.generate_summary and self._summarizer:
                self._generate_summary(document, ctx)

            # Step 2: Chunk the document
            chunks = self._chunker.chunk_document(document)
            ctx.chunks = chunks
            logger.info(f"Created {len(chunks)} chunks")

            if not chunks:
                logger.warning("No chunks created from document")
                return ctx

            # Step 3: Process all chunks through Stage 1 (Splitting)
            self._process_stage1(ctx)

            # Step 4: Deduplicate raw triples
            if self.config.deduplicate_across_chunks:
                self._deduplicate_triples(ctx)

            # Step 5: Process through remaining stages (2-6)
            self._process_remaining_stages(ctx)

            # Step 6: Add citations to statements
            self._add_citations(ctx)

        except Exception as e:
            logger.exception("Document pipeline failed")
            ctx.add_error(f"Pipeline error: {str(e)}")
            raise

        total_time = time.time() - start_time
        ctx.record_timing("total", total_time)

        logger.info(
            f"Document pipeline complete: {ctx.statement_count} statements, "
            f"{ctx.duplicates_removed} duplicates removed, "
            f"{total_time:.2f}s"
        )

        return ctx

    def _generate_summary(self, document: Document, ctx: DocumentContext) -> None:
        """Generate document summary."""
        logger.info("Generating document summary")
        start_time = time.time()

        try:
            summary = self._summarizer.summarize(document)
            document.summary = summary
            logger.info(f"Generated summary: {len(summary)} chars")
        except Exception as e:
            logger.error(f"Summary generation failed: {e}")
            ctx.add_warning(f"Summary generation failed: {e}")

        ctx.record_timing("summarization", time.time() - start_time)

    def _process_stage1(self, ctx: DocumentContext) -> None:
        """Process all chunks through Stage 1 (Splitting) using batch processing."""
        from ..pipeline.registry import PluginRegistry
        from ..plugins.base import PluginCapability

        logger.info(f"Processing {len(ctx.chunks)} chunks through Stage 1 (batch mode)")
        start_time = time.time()

        # Get the splitter plugin
        splitters = PluginRegistry.get_splitters()
        if not splitters:
            logger.warning("No splitter plugins registered")
            return

        # Use first enabled splitter
        splitter = None
        for s in splitters:
            plugin_enabled = (
                self.config.pipeline_config is None or
                self.config.pipeline_config.is_plugin_enabled(s.name)
            )
            if plugin_enabled:
                splitter = s
                break

        if not splitter:
            logger.warning("No enabled splitter plugin found")
            return

        # Extract all chunk texts
        chunk_texts = [chunk.text for chunk in ctx.chunks]

        # Create a dummy context for the splitter
        from ..pipeline.context import PipelineContext
        dummy_ctx = PipelineContext(
            source_text="",  # Not used for batch splitting
            source_metadata=self.config.pipeline_config.model_dump() if self.config.pipeline_config else {},
        )

        all_triples = []

        # Require batch processing capability
        if PluginCapability.BATCH_PROCESSING not in splitter.capabilities:
            raise RuntimeError(
                f"Splitter plugin '{splitter.name}' does not support batch processing. "
                "Document pipeline requires BATCH_PROCESSING capability for efficient GPU utilization."
            )

        logger.info(f"Using batch splitting with {splitter.name}")
        batch_results = splitter.split_batch(chunk_texts, dummy_ctx)

        # Annotate triples with document/chunk info
        for chunk, triples in zip(ctx.chunks, batch_results):
            for triple in triples:
                triple.document_id = ctx.document.document_id
                triple.chunk_index = chunk.chunk_index
                triple.page_number = chunk.primary_page
                all_triples.append(triple)

        ctx.raw_triples = all_triples
        ctx.pre_dedup_count = len(all_triples)

        ctx.record_timing("stage1_batch", time.time() - start_time)
        logger.info(f"Stage 1 produced {len(all_triples)} raw triples from {len(ctx.chunks)} chunks")

    def _deduplicate_triples(self, ctx: DocumentContext) -> None:
        """Deduplicate raw triples across chunks."""
        logger.info(f"Deduplicating {len(ctx.raw_triples)} triples")
        start_time = time.time()

        original_count = len(ctx.raw_triples)
        ctx.raw_triples = self._deduplicator.deduplicate_batch(ctx.raw_triples)
        ctx.post_dedup_count = len(ctx.raw_triples)

        removed = original_count - len(ctx.raw_triples)
        ctx.record_timing("deduplication", time.time() - start_time)
        logger.info(f"Removed {removed} duplicate triples")

    def _process_remaining_stages(self, ctx: DocumentContext) -> None:
        """Process through stages 2-6."""
        logger.info(f"Processing {len(ctx.raw_triples)} triples through stages 2-6")
        start_time = time.time()

        # Create a pipeline config for stages 2-6
        # Exclude enabled_stages from base config to avoid duplicate keyword argument
        base_config = {}
        if self.config.pipeline_config:
            base_config = self.config.pipeline_config.model_dump(exclude={"enabled_stages"})
        stages_config = PipelineConfig(
            enabled_stages={2, 3, 4, 5, 6},
            **base_config
        )

        # Create a combined context with all raw triples
        from ..pipeline.context import PipelineContext

        combined_ctx = PipelineContext(
            source_text=ctx.document.full_text,
            source_metadata={
                "document_id": ctx.document.document_id,
                "title": ctx.document.metadata.title,
            },
        )
        combined_ctx.raw_triples = ctx.raw_triples

        # Run stages 2-6
        pipeline = ExtractionPipeline(stages_config)

        # We need to manually run stages since we're providing pre-existing triples
        # Stage 2: Extraction
        if stages_config.is_stage_enabled(2):
            combined_ctx = pipeline._run_extraction(combined_ctx)

        # Propagate document info to statements
        for stmt in combined_ctx.statements:
            # Find the source triple to get document info
            for triple in ctx.raw_triples:
                if triple.source_sentence in stmt.source_text:
                    stmt.document_id = triple.document_id
                    stmt.chunk_index = triple.chunk_index
                    stmt.page_number = triple.page_number
                    break

        # Stage 3: Qualification
        if stages_config.is_stage_enabled(3):
            combined_ctx = pipeline._run_qualification(combined_ctx)

        # Stage 4: Canonicalization
        if stages_config.is_stage_enabled(4):
            combined_ctx = pipeline._run_canonicalization(combined_ctx)

        # Stage 5: Labeling
        if stages_config.is_stage_enabled(5):
            combined_ctx = pipeline._run_labeling(combined_ctx)

        # Stage 6: Taxonomy
        if stages_config.is_stage_enabled(6):
            combined_ctx = pipeline._run_taxonomy(combined_ctx)

        # Propagate document info to labeled statements
        for labeled_stmt in combined_ctx.labeled_statements:
            labeled_stmt.document_id = labeled_stmt.statement.document_id
            labeled_stmt.page_number = labeled_stmt.statement.page_number

        ctx.statements = combined_ctx.statements
        ctx.labeled_statements = combined_ctx.labeled_statements

        # Merge timings
        for stage, duration in combined_ctx.stage_timings.items():
            ctx.record_timing(stage, duration)

        ctx.record_timing("stages_2_6_batch", time.time() - start_time)
        logger.info(f"Stages 2-6 produced {len(ctx.labeled_statements)} labeled statements")

    def _add_citations(self, ctx: DocumentContext) -> None:
        """Add citations to all labeled statements."""
        logger.info("Adding citations to statements")

        for stmt in ctx.labeled_statements:
            citation = ctx.document.metadata.format_citation(stmt.page_number)
            stmt.citation = citation if citation else None

    def process_text(
        self,
        text: str,
        title: Optional[str] = None,
        **metadata_kwargs,
    ) -> DocumentContext:
        """
        Process plain text through the document pipeline.

        Convenience method that creates a Document from text.

        Args:
            text: Text to process
            title: Optional document title
            **metadata_kwargs: Additional document metadata

        Returns:
            DocumentContext with extraction results
        """
        document = Document.from_text(text, title=title, **metadata_kwargs)
        return self.process(document)

    async def process_url(
        self,
        url: str,
        loader_config: Optional["URLLoaderConfig"] = None,
    ) -> DocumentContext:
        """
        Process a URL through the document pipeline.

        Fetches the URL, extracts content (HTML or PDF), and processes it.

        Args:
            url: URL to process
            loader_config: Optional loader configuration

        Returns:
            DocumentContext with extraction results
        """
        from .loader import URLLoader, URLLoaderConfig

        loader = URLLoader(loader_config or URLLoaderConfig())
        document = await loader.load(url)
        return self.process(document)

    def process_url_sync(
        self,
        url: str,
        loader_config: Optional["URLLoaderConfig"] = None,
    ) -> DocumentContext:
        """
        Process a URL through the document pipeline (synchronous).

        Fetches the URL, extracts content (HTML or PDF), and processes it.

        Args:
            url: URL to process
            loader_config: Optional loader configuration

        Returns:
            DocumentContext with extraction results
        """
        import asyncio
        return asyncio.run(self.process_url(url, loader_config))
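The class above exposes three entry points: process() for a prepared Document, process_text() for plain text, and process_url()/process_url_sync() for remote HTML or PDF. A minimal usage sketch, assuming the import paths match the new statement_extractor/document package listed in the RECORD; the sample text and config values are illustrative only:

# Sketch of driving the new document pipeline; class and field names come from the
# diff above, the exact import locations are an assumption based on the file list.
from statement_extractor.document.pipeline import DocumentPipeline, DocumentPipelineConfig
from statement_extractor.models.document import ChunkingConfig

config = DocumentPipelineConfig(
    chunking=ChunkingConfig(),   # library defaults; tune chunking here if needed
    generate_summary=False,      # skip the Gemma3 summarization pass
    deduplicate_across_chunks=True,
)

pipeline = DocumentPipeline(config)

# Plain-text convenience entry point (wraps Document.from_text + process()).
ctx = pipeline.process_text("Acme Corp acquired Widget Ltd in 2023...", title="Press release")
for stmt in ctx.labeled_statements:
    print(stmt, getattr(stmt, "citation", None))

# URL entry point (synchronous wrapper around the async loader):
# ctx = pipeline.process_url_sync("https://example.com/annual-report.pdf")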
statement_extractor/document/summarizer.py ADDED

@@ -0,0 +1,195 @@
"""
DocumentSummarizer - Generate document summaries using Gemma3.

Creates concise summaries focused on entities, events, and relationships
that are useful for providing context during extraction.
"""

import logging
from typing import Optional

from ..models.document import Document

logger = logging.getLogger(__name__)


class DocumentSummarizer:
    """
    Generates document summaries using the Gemma3 LLM.

    Summaries focus on:
    - Key entities mentioned
    - Important events and actions
    - Relationships between entities
    """

    MAX_INPUT_TOKENS = 10_000
    DEFAULT_MAX_OUTPUT_TOKENS = 300

    def __init__(
        self,
        max_input_tokens: int = MAX_INPUT_TOKENS,
        max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS,
    ):
        """
        Initialize the summarizer.

        Args:
            max_input_tokens: Maximum tokens of input to send to the LLM
            max_output_tokens: Maximum tokens for the summary output
        """
        self._max_input_tokens = max_input_tokens
        self._max_output_tokens = max_output_tokens
        self._llm = None
        self._tokenizer = None

    @property
    def llm(self):
        """Lazy-load the LLM."""
        if self._llm is None:
            from ..llm import get_llm
            logger.debug("Loading LLM for summarization")
            self._llm = get_llm()
        return self._llm

    @property
    def tokenizer(self):
        """Lazy-load tokenizer for token counting."""
        if self._tokenizer is None:
            from transformers import AutoTokenizer
            self._tokenizer = AutoTokenizer.from_pretrained(
                "Corp-o-Rate-Community/statement-extractor",
                trust_remote_code=True,
            )
        return self._tokenizer

    def _count_tokens(self, text: str) -> int:
        """Count tokens in text."""
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def _truncate_to_tokens(self, text: str, max_tokens: int) -> str:
        """
        Truncate text to a maximum number of tokens.

        Tries to truncate at sentence boundaries when possible.
        """
        token_count = self._count_tokens(text)

        if token_count <= max_tokens:
            return text

        # Estimate chars per token
        chars_per_token = len(text) / token_count
        target_chars = int(max_tokens * chars_per_token * 0.95)  # 5% buffer

        # Truncate
        truncated = text[:target_chars]

        # Try to end at a sentence boundary
        last_period = truncated.rfind(". ")
        last_newline = truncated.rfind("\n")
        split_pos = max(last_period, last_newline)

        if split_pos > target_chars * 0.7:  # Don't lose too much text
            truncated = truncated[:split_pos + 1]

        logger.debug(f"Truncated text from {len(text)} to {len(truncated)} chars")
        return truncated

    def summarize(
        self,
        document: Document,
        custom_prompt: Optional[str] = None,
    ) -> str:
        """
        Generate a summary of the document.

        Args:
            document: Document to summarize
            custom_prompt: Optional custom prompt (uses default if not provided)

        Returns:
            Summary string
        """
        if not document.full_text.strip():
            logger.warning("Cannot summarize empty document")
            return ""

        logger.info(f"Generating summary for document {document.document_id}")

        # Truncate text to max input tokens
        text = self._truncate_to_tokens(document.full_text, self._max_input_tokens)

        # Build prompt
        if custom_prompt:
            prompt = f"{custom_prompt}\n\n{text}"
        else:
            prompt = self._build_prompt(text, document)

        # Generate summary
        try:
            summary = self.llm.generate(
                prompt=prompt,
                max_tokens=self._max_output_tokens,
                stop=["\n\n\n", "---"],
            )
            summary = summary.strip()
            logger.info(f"Generated summary ({len(summary)} chars):")
            # Log summary with indentation for readability
            for line in summary.split("\n"):
                logger.info(f"  {line}")
            return summary

        except Exception as e:
            logger.error(f"Summary generation failed: {e}")
            raise

    def _build_prompt(self, text: str, document: Document) -> str:
        """Build the summarization prompt."""
        # Include document metadata context if available
        context_parts = []
        if document.metadata.title:
            context_parts.append(f"Title: {document.metadata.title}")
        if document.metadata.authors:
            context_parts.append(f"Authors: {', '.join(document.metadata.authors)}")
        if document.metadata.source_type:
            context_parts.append(f"Source type: {document.metadata.source_type}")

        context = "\n".join(context_parts) if context_parts else ""

        prompt = f"""Summarize the following document, focusing on:
1. Key entities (companies, people, locations) mentioned
2. Important events, actions, and decisions
3. Relationships between entities
4. Main topics and themes

Keep the summary concise (2-3 paragraphs) and factual.

{context}

Document text:
{text}

Summary:"""

        return prompt

    def summarize_text(
        self,
        text: str,
        title: Optional[str] = None,
    ) -> str:
        """
        Generate a summary from plain text.

        Convenience method that creates a temporary Document.

        Args:
            text: Text to summarize
            title: Optional document title for context

        Returns:
            Summary string
        """
        document = Document.from_text(text, title=title)
        return self.summarize(document)
statement_extractor/extractor.py CHANGED

@@ -392,7 +392,7 @@ class StatementExtractor:
         This is the new extraction pipeline that:
         1. Generates multiple candidates via DBS
         2. Parses each to statements
-        3. Scores each triple for
+        3. Scores each triple for quality (semantic + entity)
         4. Merges top beams or selects best beam
         5. Deduplicates using embeddings (if enabled)
         """
statement_extractor/models/__init__.py CHANGED

@@ -43,10 +43,17 @@ else:
 
     # New pipeline models
     from .entity import ExtractedEntity
-    from .statement import RawTriple, PipelineStatement
-    from .qualifiers import EntityQualifiers, QualifiedEntity
+    from .statement import SplitSentence, RawTriple, PipelineStatement
+    from .qualifiers import EntityQualifiers, QualifiedEntity, ResolvedRole, ResolvedOrganization
     from .canonical import CanonicalMatch, CanonicalEntity
     from .labels import StatementLabel, LabeledStatement, TaxonomyResult
+    from .document import (
+        Document,
+        DocumentMetadata,
+        DocumentPage,
+        TextChunk,
+        ChunkingConfig,
+    )
 
 __all__ = [
     # Re-exported from original models.py (backward compatibility)

@@ -62,13 +69,22 @@ __all__ = [
     "ExtractionOptions",
     # New pipeline models
     "ExtractedEntity",
-    "
+    "SplitSentence",
+    "RawTriple",  # Backwards compatibility alias for SplitSentence
     "PipelineStatement",
     "EntityQualifiers",
     "QualifiedEntity",
+    "ResolvedRole",
+    "ResolvedOrganization",
     "CanonicalMatch",
     "CanonicalEntity",
     "StatementLabel",
     "LabeledStatement",
     "TaxonomyResult",
+    # Document models
+    "Document",
+    "DocumentMetadata",
+    "DocumentPage",
+    "TextChunk",
+    "ChunkingConfig",
 ]
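With these re-exports in place, the new document models and the renamed split types become importable from the package-level models module. A quick sketch of the consumer-side imports, assuming the package re-exports mirror the __all__ shown above:

from statement_extractor.models import (
    Document,
    DocumentMetadata,
    TextChunk,
    ChunkingConfig,
    SplitSentence,
    RawTriple,  # documented above as a backwards-compatibility alias for SplitSentence
)

doc = Document.from_text("Some filing text...", title="10-K excerpt")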
statement_extractor/models/canonical.py CHANGED

@@ -64,9 +64,52 @@ class CanonicalEntity(BaseModel):
     )
     fqn: str = Field(
         ...,
-        description="Fully qualified name, e.g., '
+        description="Fully qualified name, e.g., 'AMAZON CORP INC (SEC-CIK,USA)'"
     )
 
+    @property
+    def name(self) -> Optional[str]:
+        """Get the canonical/legal name if available."""
+        # Prefer legal_name from qualifiers (set by embedding qualifier)
+        if self.qualified_entity.qualifiers.legal_name:
+            return self.qualified_entity.qualifiers.legal_name
+        # Fall back to canonical match name
+        if self.canonical_match and self.canonical_match.canonical_name:
+            return self.canonical_match.canonical_name
+        return None
+
+    @property
+    def qualifiers_dict(self) -> Optional[dict[str, str]]:
+        """
+        Get qualifiers as a dict for serialization.
+
+        Returns a dict with keys like: legal_name, region, source, source_id
+        Only returns non-None values. Returns None if no qualifiers are set.
+        """
+        qualifiers = self.qualified_entity.qualifiers
+        identifiers = qualifiers.identifiers
+        result = {}
+
+        # Add legal name
+        if qualifiers.legal_name:
+            result["legal_name"] = qualifiers.legal_name
+
+        # Add region (prefer region, fall back to jurisdiction/country)
+        if qualifiers.region:
+            result["region"] = qualifiers.region
+        elif qualifiers.jurisdiction:
+            result["region"] = qualifiers.jurisdiction
+        elif qualifiers.country:
+            result["region"] = qualifiers.country
+
+        # Add source and source_id from identifiers
+        if "source" in identifiers:
+            result["source"] = identifiers["source"]
+        if "source_id" in identifiers:
+            result["source_id"] = identifiers["source_id"]
+
+        return result if result else None
+
     @classmethod
     def from_qualified(
         cls,