corp-extractor 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- corp_extractor-0.2.3.dist-info/METADATA +280 -0
- corp_extractor-0.2.3.dist-info/RECORD +9 -0
- corp_extractor-0.2.3.dist-info/WHEEL +4 -0
- statement_extractor/__init__.py +110 -0
- statement_extractor/canonicalization.py +196 -0
- statement_extractor/extractor.py +649 -0
- statement_extractor/models.py +284 -0
- statement_extractor/predicate_comparer.py +611 -0
- statement_extractor/scoring.py +419 -0
|
@@ -0,0 +1,649 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Statement Extractor - Extract structured statements from text using T5-Gemma 2.
|
|
3
|
+
|
|
4
|
+
This module uses Diverse Beam Search (Vijayakumar et al., 2016) to generate
|
|
5
|
+
multiple candidate extractions and selects/merges the best results using
|
|
6
|
+
quality scoring.
|
|
7
|
+
|
|
8
|
+
Paper: https://arxiv.org/abs/1610.02424
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import re
|
|
13
|
+
import xml.etree.ElementTree as ET
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
18
|
+
|
|
19
|
+
from .models import (
|
|
20
|
+
Entity,
|
|
21
|
+
EntityType,
|
|
22
|
+
ExtractionOptions,
|
|
23
|
+
ExtractionResult,
|
|
24
|
+
PredicateComparisonConfig,
|
|
25
|
+
PredicateTaxonomy,
|
|
26
|
+
ScoringConfig,
|
|
27
|
+
Statement,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Module-level logger; all extraction progress and parse warnings go through it.
logger = logging.getLogger(__name__)

# Default model: HuggingFace model ID used by StatementExtractor when no
# model_id is given (and therefore by the module-level convenience functions,
# which construct StatementExtractor() with no arguments).
DEFAULT_MODEL_ID = "Corp-o-Rate-Community/statement-extractor"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class StatementExtractor:
    """
    Extract structured statements from unstructured text.

    Uses the T5-Gemma 2 statement extraction model with Diverse Beam Search
    to generate high-quality subject-predicate-object triples.

    Features:
    - Quality-based beam scoring (not just longest output)
    - Beam merging for better coverage
    - Embedding-based predicate comparison for smart deduplication
    - Configurable precision/recall tradeoff

    The model and tokenizer are loaded lazily on first access of ``model`` /
    ``tokenizer``. The predicate comparer is also created lazily and reused
    for as long as the taxonomy/config it was built from stay the same.

    Example:
        >>> extractor = StatementExtractor()
        >>> result = extractor.extract("Apple Inc. announced a new iPhone today.")
        >>> for stmt in result:
        ...     print(stmt)
        Apple Inc. -- announced --> a new iPhone
    """

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None,
        predicate_taxonomy: Optional[PredicateTaxonomy] = None,
        predicate_config: Optional[PredicateComparisonConfig] = None,
        scoring_config: Optional[ScoringConfig] = None,
    ):
        """
        Initialize the statement extractor.

        No model weights are loaded here; loading happens on first use.

        Args:
            model_id: HuggingFace model ID or local path
            device: Device to use ('cuda', 'cpu', or None for auto-detect)
            torch_dtype: Torch dtype (default: bfloat16 on GPU, float32 on CPU).
                An explicit dtype is honoured on every device.
            predicate_taxonomy: Optional taxonomy for predicate normalization
            predicate_config: Configuration for predicate comparison
            scoring_config: Configuration for quality scoring
        """
        self.model_id = model_id
        self._model: Optional[AutoModelForSeq2SeqLM] = None
        self._tokenizer: Optional[AutoTokenizer] = None

        # Auto-detect device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Auto-detect dtype
        if torch_dtype is None:
            self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        else:
            self.torch_dtype = torch_dtype

        # Scoring and comparison config (per-call ExtractionOptions take precedence)
        self._predicate_taxonomy = predicate_taxonomy
        self._predicate_config = predicate_config
        self._scoring_config = scoring_config

        # Lazy-loaded components
        self._beam_scorer = None
        self._predicate_comparer = None
        # (taxonomy, config) objects the cached predicate comparer was built from
        self._predicate_comparer_key: Optional[tuple] = None

    def _load_model(self) -> None:
        """Load model and tokenizer if not already loaded.

        On device ``"cuda"`` the model is loaded with ``device_map="auto"``.
        On any other device it is loaded and then moved with
        ``.to(self.device)``. ``self.torch_dtype`` is applied in both cases.
        """
        if self._model is not None:
            return

        logger.info(f"Loading model: {self.model_id}")

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            trust_remote_code=True,
        )

        if self.device == "cuda":
            self._model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_id,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True,
                device_map="auto",
            )
        else:
            # Pass the dtype here too so an explicitly requested torch_dtype is
            # honoured off-GPU (the auto-detected default is float32).
            self._model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_id,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True,
            )
            self._model = self._model.to(self.device)

        logger.info(f"Model loaded on {self.device}")

    def _get_beam_scorer(self, options: ExtractionOptions):
        """Create a beam scorer.

        The scoring config is resolved in this order: ``options``, then the
        constructor value, then ``ScoringConfig()`` defaults.
        """
        from .scoring import BeamScorer

        config = options.scoring_config or self._scoring_config or ScoringConfig()
        return BeamScorer(config=config)

    def _get_predicate_comparer(self, options: ExtractionOptions):
        """Get or create the predicate comparer (None if embeddings are disabled).

        Taxonomy and config are resolved from ``options`` first, then from the
        constructor values. The comparer is cached on the instance. It is reused
        while the resolved taxonomy and config are the same objects (identity)
        as those it was built from; otherwise a new one is built and cached.
        """
        if not options.embedding_dedup:
            return None

        from .predicate_comparer import PredicateComparer

        taxonomy = options.predicate_taxonomy or self._predicate_taxonomy
        requested_config = options.predicate_config or self._predicate_config

        cached_key = self._predicate_comparer_key
        if (
            self._predicate_comparer is not None
            and cached_key is not None
            and cached_key[0] is taxonomy
            and cached_key[1] is requested_config
        ):
            return self._predicate_comparer

        config = requested_config or PredicateComparisonConfig()
        comparer = PredicateComparer(taxonomy=taxonomy, config=config, device=self.device)
        self._predicate_comparer = comparer
        self._predicate_comparer_key = (taxonomy, requested_config)
        return comparer

    @property
    def model(self) -> AutoModelForSeq2SeqLM:
        """Get the model, loading it if necessary."""
        self._load_model()
        return self._model

    @property
    def tokenizer(self) -> AutoTokenizer:
        """Get the tokenizer, loading it if necessary."""
        self._load_model()
        return self._tokenizer

    def extract(
        self,
        text: str,
        options: Optional[ExtractionOptions] = None,
    ) -> ExtractionResult:
        """
        Extract statements from text.

        Text that does not already start with ``<page>`` is wrapped in
        ``<page>...</page>`` before generation. The unwrapped text is kept for
        scoring and as ``source_text`` of the result.

        Args:
            text: Input text to extract statements from
            options: Extraction options (uses defaults if not provided)

        Returns:
            ExtractionResult containing the extracted statements
        """
        if options is None:
            options = ExtractionOptions()

        # Store original text for scoring
        original_text = text

        # Wrap text in page tags if not already wrapped
        if not text.startswith("<page>"):
            text = f"<page>{text}</page>"

        # Run extraction with retry logic
        statements = self._extract_with_scoring(text, original_text, options)

        return ExtractionResult(
            statements=statements,
            source_text=original_text,
        )

    def extract_as_xml(
        self,
        text: str,
        options: Optional[ExtractionOptions] = None,
    ) -> str:
        """
        Extract statements and return raw XML output.

        Note: This bypasses the new scoring/merging logic for backward compatibility.
        Use extract() for full quality scoring.

        Args:
            text: Input text to extract statements from
            options: Extraction options

        Returns:
            XML string with <statements> containing <stmt> elements
        """
        if options is None:
            options = ExtractionOptions()

        if not text.startswith("<page>"):
            text = f"<page>{text}</page>"

        return self._extract_raw_xml(text, options)

    def extract_as_json(
        self,
        text: str,
        options: Optional[ExtractionOptions] = None,
        indent: Optional[int] = 2,
    ) -> str:
        """
        Extract statements and return JSON string.

        Args:
            text: Input text to extract statements from
            options: Extraction options
            indent: JSON indentation (None for compact)

        Returns:
            JSON string representation of the extraction result
        """
        result = self.extract(text, options)
        return result.model_dump_json(indent=indent)

    def extract_as_dict(
        self,
        text: str,
        options: Optional[ExtractionOptions] = None,
    ) -> dict:
        """
        Extract statements and return as dictionary.

        Args:
            text: Input text to extract statements from
            options: Extraction options

        Returns:
            Dictionary representation of the extraction result
        """
        result = self.extract(text, options)
        return result.model_dump()

    def _extract_with_scoring(
        self,
        text: str,
        original_text: str,
        options: ExtractionOptions,
    ) -> list[Statement]:
        """
        Extract statements with quality scoring and beam merging.

        This is the new extraction pipeline that:
        1. Generates multiple candidate extractions via DBS. It retries up to
           ``options.max_attempts`` times while fewer than
           ``num_sentences * min_statement_ratio`` statements are found, and
           stops early once an attempt reproduces the previous beams.
        2. Parses each to statements
        3. Scores each triple for groundedness
        4. Merges top beams or selects best beam
        5. Deduplicates using embeddings (if enabled), falling back to exact
           matching if the embedding path raises

        Args:
            text: ``<page>``-wrapped text fed to the model
            original_text: Unwrapped text used for scoring
            options: Extraction options

        Returns:
            The selected/merged (and possibly deduplicated) statements; empty
            if no candidate parsed to any statement.
        """
        # Tokenize input (inputs longer than 4096 tokens are truncated)
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=4096,
            truncation=True,
        ).to(self.device)

        # Count sentences for quality check
        num_sentences = self._count_sentences(text)
        min_expected = int(num_sentences * options.min_statement_ratio)

        logger.info(f"Input has ~{num_sentences} sentences, expecting >= {min_expected} statements")

        # Get beam scorer
        beam_scorer = self._get_beam_scorer(options)

        all_candidates: list[list[Statement]] = []
        previous_beams: Optional[list[str]] = None

        for attempt in range(options.max_attempts):
            # Generate candidate beams
            candidates = self._generate_candidate_beams(inputs, options)

            # Generation uses do_sample=False, so a retry normally reproduces the
            # previous beams exactly. Adding them again would only duplicate
            # candidates, and later attempts cannot produce anything new.
            if candidates == previous_beams:
                logger.info(f"Attempt {attempt + 1}/{options.max_attempts}: beams identical to previous attempt, stopping retries")
                break
            previous_beams = candidates

            # Parse each candidate to statements (beams with none are dropped)
            parsed_candidates = []
            for xml_output in candidates:
                statements = self._parse_xml_to_statements(xml_output)
                if statements:
                    parsed_candidates.append(statements)

            all_candidates.extend(parsed_candidates)

            # Check if we have enough statements
            total_stmts = sum(len(c) for c in parsed_candidates)
            logger.info(f"Attempt {attempt + 1}/{options.max_attempts}: {len(parsed_candidates)} beams, {total_stmts} total statements")

            if total_stmts >= min_expected:
                break

        if not all_candidates:
            return []

        # Select or merge beams
        if options.merge_beams:
            statements = beam_scorer.merge_beams(all_candidates, original_text)
        else:
            statements = beam_scorer.select_best_beam(all_candidates, original_text)

        # Apply embedding-based deduplication if enabled
        if options.embedding_dedup and options.deduplicate:
            try:
                comparer = self._get_predicate_comparer(options)
                if comparer:
                    statements = comparer.deduplicate_statements(
                        statements,
                        entity_canonicalizer=options.entity_canonicalizer
                    )
                    # Also normalize predicates if taxonomy provided
                    if options.predicate_taxonomy or self._predicate_taxonomy:
                        statements = comparer.normalize_predicates(statements)
            except Exception as e:
                # Best-effort: any failure in the embedding path (e.g. loading
                # the comparer) degrades to exact-match deduplication.
                logger.warning(f"Embedding deduplication failed, falling back to exact match: {e}")
                statements = self._deduplicate_statements_exact(statements, options)
        elif options.deduplicate:
            statements = self._deduplicate_statements_exact(statements, options)

        return statements

    def _generate_candidate_beams(
        self,
        inputs,
        options: ExtractionOptions,
    ) -> list[str]:
        """Generate multiple candidate beams using diverse beam search.

        Runs ``generate`` with ``options.num_beams`` beams, each in its own beam
        group, and returns the same number of sequences. Decoding is greedy
        (``do_sample=False``). Each decoded beam is truncated after the first
        ``</statements>``. Beams without that tag are logged and dropped. If
        every beam is dropped, the untruncated first beam is returned as a
        fallback.
        """
        num_seqs = options.num_beams

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=options.max_new_tokens,
                num_beams=num_seqs,
                num_beam_groups=num_seqs,
                num_return_sequences=num_seqs,
                diversity_penalty=options.diversity_penalty,
                do_sample=False,
                # NOTE(review): presumably needed because recent transformers
                # releases serve group (diverse) beam search as remote code — confirm.
                trust_remote_code=True,
            )

        # Decode and process candidates
        end_tag = "</statements>"
        candidates: list[str] = []

        for i, output in enumerate(outputs):
            decoded = self.tokenizer.decode(output, skip_special_tokens=True)
            output_len = len(output)

            # Truncate at </statements>
            if end_tag in decoded:
                end_pos = decoded.find(end_tag) + len(end_tag)
                decoded = decoded[:end_pos]
                candidates.append(decoded)
                logger.debug(f"Beam {i}: {output_len} tokens, found end tag, {len(decoded)} chars")
            else:
                # Log the issue - likely truncated
                logger.warning(f"Beam {i}: {output_len} tokens, NO end tag found (truncated?)")
                logger.warning(f"Beam {i} full output ({len(decoded)} chars):\n{decoded}")

        # Include fallback if no valid candidates
        if not candidates and len(outputs) > 0:
            fallback = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.warning(f"Using fallback beam (no valid candidates found), {len(fallback)} chars")
            candidates.append(fallback)

        return candidates

    def _extract_raw_xml(
        self,
        text: str,
        options: ExtractionOptions,
    ) -> str:
        """
        Extract and return raw XML (legacy method for backward compatibility).

        Uses length-based selection like the original implementation: the
        longest candidate string across all attempts is returned. Retries stop
        once an attempt reproduces the previous beams, because decoding is
        deterministic.
        """
        # Tokenize input (inputs longer than 4096 tokens are truncated)
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=4096,
            truncation=True,
        ).to(self.device)

        num_sentences = self._count_sentences(text)
        min_expected = int(num_sentences * options.min_statement_ratio)

        all_results: list[tuple[str, int]] = []
        previous_beams: Optional[list[str]] = None

        for attempt in range(options.max_attempts):
            candidates = self._generate_candidate_beams(inputs, options)

            # do_sample=False: identical beams mean further retries are futile.
            if candidates == previous_beams:
                logger.info(f"Attempt {attempt + 1}/{options.max_attempts}: beams identical to previous attempt, stopping retries")
                break
            previous_beams = candidates

            for candidate in candidates:
                if options.deduplicate:
                    candidate = self._deduplicate_xml(candidate)
                num_stmts = self._count_statements(candidate)
                all_results.append((candidate, num_stmts))

            best_so_far = max(all_results, key=lambda x: x[1])[1] if all_results else 0
            if best_so_far >= min_expected:
                break

        if not all_results:
            return "<statements></statements>"

        # Select best result (longest, for backward compatibility)
        return max(all_results, key=lambda x: len(x[0]))[0]

    def _deduplicate_statements_exact(
        self,
        statements: list[Statement],
        options: ExtractionOptions,
    ) -> list[Statement]:
        """Deduplicate statements using exact text matching.

        Delegates to ``canonicalization.deduplicate_statements_exact``, passing
        ``options.entity_canonicalizer``.
        """
        from .canonicalization import deduplicate_statements_exact
        return deduplicate_statements_exact(
            statements,
            entity_canonicalizer=options.entity_canonicalizer
        )

    def _deduplicate_xml(self, xml_output: str) -> str:
        """Remove duplicate <stmt> blocks from XML output (legacy method).

        Duplicates are detected on the stripped, lower-cased
        (subject, predicate, object) text, and the first occurrence is kept.
        Unparseable XML, or XML whose root is not <statements>, is returned
        unchanged. Otherwise the result is re-serialised with ElementTree.
        """
        try:
            root = ET.fromstring(xml_output)
        except ET.ParseError:
            return xml_output

        if root.tag != 'statements':
            return xml_output

        seen: set[tuple[str, str, str]] = set()
        unique_stmts: list[ET.Element] = []

        for stmt in root.findall('stmt'):
            subject = stmt.findtext('subject', '').strip().lower()
            predicate = stmt.findtext('predicate', '').strip().lower()
            obj = stmt.findtext('object', '').strip().lower()
            key = (subject, predicate, obj)

            if key not in seen:
                seen.add(key)
                unique_stmts.append(stmt)

        new_root = ET.Element('statements')
        for stmt in unique_stmts:
            new_root.append(stmt)

        return ET.tostring(new_root, encoding='unicode')

    def _parse_xml_to_statements(self, xml_output: str) -> list[Statement]:
        """Parse XML output into Statement objects.

        Only <stmt> elements with non-empty subject, predicate and object text
        are kept. Unparseable XML, or a root other than <statements>, yields an
        empty list. Per-statement errors are logged and skipped.
        """
        statements: list[Statement] = []

        try:
            root = ET.fromstring(xml_output)
        except ET.ParseError as e:
            # Log full output for debugging
            logger.warning(f"Failed to parse XML output: {e}")
            logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
            return statements

        if root.tag != 'statements':
            return statements

        for stmt_elem in root.findall('stmt'):
            try:
                # Parse subject
                subject_elem = stmt_elem.find('subject')
                subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
                subject_type = self._parse_entity_type(subject_elem.get('type') if subject_elem is not None else None)

                # Parse object
                object_elem = stmt_elem.find('object')
                object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
                object_type = self._parse_entity_type(object_elem.get('type') if object_elem is not None else None)

                # Parse predicate
                predicate_elem = stmt_elem.find('predicate')
                predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""

                # Parse source text
                text_elem = stmt_elem.find('text')
                source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else None

                if subject_text and predicate and object_text:
                    statements.append(Statement(
                        subject=Entity(text=subject_text, type=subject_type),
                        predicate=predicate,
                        object=Entity(text=object_text, type=object_type),
                        source_text=source_text,
                    ))
            except Exception as e:
                logger.warning(f"Failed to parse statement: {e}")
                continue

        return statements

    def _parse_entity_type(self, type_str: Optional[str]) -> EntityType:
        """Parse entity type string to EntityType enum.

        The string is upper-cased and looked up by enum value. A missing type,
        or one that does not match any value, maps to ``EntityType.UNKNOWN``.
        """
        if type_str is None:
            return EntityType.UNKNOWN
        try:
            return EntityType(type_str.upper())
        except ValueError:
            return EntityType.UNKNOWN

    @staticmethod
    def _count_sentences(text: str) -> int:
        """Count approximate number of sentences in text.

        Tags are stripped, the text is split on runs of ``.``, ``!`` and ``?``,
        and fragments of 10 characters or fewer are ignored. Returns at least 1.
        """
        clean_text = re.sub(r'<[^>]+>', '', text)
        sentences = re.split(r'[.!?]+', clean_text)
        sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
        return max(1, len(sentences))

    @staticmethod
    def _count_statements(xml_output: str) -> int:
        """Count number of <stmt> opening tags in output.

        Tags carrying attributes (``<stmt ...>``) are counted too. ``\\b`` stops
        the pattern from matching longer tag names such as ``<stmts>``.
        """
        return len(re.findall(r'<stmt\b', xml_output))
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
# Convenience functions for simple usage

_default_extractor: Optional[StatementExtractor] = None


def _get_default_extractor() -> StatementExtractor:
    """Return the shared module-level StatementExtractor.

    The instance is built with default arguments on the first call and stored
    in the module global ``_default_extractor``; later calls return that same
    object.
    """
    global _default_extractor
    if _default_extractor is not None:
        return _default_extractor
    _default_extractor = StatementExtractor()
    return _default_extractor
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def extract_statements(
    text: str,
    options: Optional[ExtractionOptions] = None,
    **kwargs,
) -> ExtractionResult:
    """
    Extract structured statements from text.

    This is a convenience function that uses a default StatementExtractor instance.
    For more control, create your own StatementExtractor.

    By default, uses embedding-based deduplication and beam merging for
    high-quality extraction. Requires sentence-transformers package.

    Args:
        text: Input text to extract statements from
        options: Extraction options (or pass individual options as kwargs)
        **kwargs: Individual option overrides (num_beams, diversity_penalty, etc.).
            Without ``options`` they build a new ExtractionOptions. With
            ``options`` they override the matching fields of a copy of it;
            the caller's object is not modified.

    Returns:
        ExtractionResult containing Statement objects

    Example:
        >>> result = extract_statements("Apple announced a new product.")
        >>> for stmt in result:
        ...     print(f"{stmt.subject.text} -> {stmt.predicate} -> {stmt.object.text}")
    """
    if kwargs:
        if options is None:
            options = ExtractionOptions(**kwargs)
        else:
            # Apply the overrides instead of silently dropping them.
            # NOTE: relies on ExtractionOptions being a pydantic v2 model (model_copy).
            options = options.model_copy(update=kwargs)
    return _get_default_extractor().extract(text, options)
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def extract_statements_as_xml(
    text: str,
    options: Optional[ExtractionOptions] = None,
    **kwargs,
) -> str:
    """
    Extract statements and return raw XML output.

    Uses the shared default StatementExtractor and its legacy raw-XML path
    (no quality scoring or beam merging).

    Args:
        text: Input text to extract statements from
        options: Extraction options
        **kwargs: Individual option overrides. Without ``options`` they build a
            new ExtractionOptions. With ``options`` they override the matching
            fields of a copy of it; the caller's object is not modified.

    Returns:
        XML string with <statements> containing <stmt> elements
    """
    if kwargs:
        if options is None:
            options = ExtractionOptions(**kwargs)
        else:
            # Apply the overrides instead of silently dropping them.
            # NOTE: relies on ExtractionOptions being a pydantic v2 model (model_copy).
            options = options.model_copy(update=kwargs)
    return _get_default_extractor().extract_as_xml(text, options)
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def extract_statements_as_json(
    text: str,
    options: Optional[ExtractionOptions] = None,
    indent: Optional[int] = 2,
    **kwargs,
) -> str:
    """
    Extract statements and return JSON string.

    Uses the shared default StatementExtractor.

    Args:
        text: Input text to extract statements from
        options: Extraction options
        indent: JSON indentation (None for compact)
        **kwargs: Individual option overrides. Without ``options`` they build a
            new ExtractionOptions. With ``options`` they override the matching
            fields of a copy of it; the caller's object is not modified.

    Returns:
        JSON string representation of the extraction result
    """
    if kwargs:
        if options is None:
            options = ExtractionOptions(**kwargs)
        else:
            # Apply the overrides instead of silently dropping them.
            # NOTE: relies on ExtractionOptions being a pydantic v2 model (model_copy).
            options = options.model_copy(update=kwargs)
    return _get_default_extractor().extract_as_json(text, options, indent)
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
def extract_statements_as_dict(
    text: str,
    options: Optional[ExtractionOptions] = None,
    **kwargs,
) -> dict:
    """
    Extract statements and return as dictionary.

    Uses the shared default StatementExtractor.

    Args:
        text: Input text to extract statements from
        options: Extraction options
        **kwargs: Individual option overrides. Without ``options`` they build a
            new ExtractionOptions. With ``options`` they override the matching
            fields of a copy of it; the caller's object is not modified.

    Returns:
        Dictionary representation of the extraction result
    """
    if kwargs:
        if options is None:
            options = ExtractionOptions(**kwargs)
        else:
            # Apply the overrides instead of silently dropping them.
            # NOTE: relies on ExtractionOptions being a pydantic v2 model (model_copy).
            options = options.model_copy(update=kwargs)
    return _get_default_extractor().extract_as_dict(text, options)
|