corp-extractor 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of corp-extractor might be problematic. Click here for more details.
- corp_extractor-0.2.7.dist-info/METADATA +377 -0
- corp_extractor-0.2.7.dist-info/RECORD +11 -0
- corp_extractor-0.2.7.dist-info/WHEEL +4 -0
- corp_extractor-0.2.7.dist-info/entry_points.txt +3 -0
- statement_extractor/__init__.py +110 -0
- statement_extractor/canonicalization.py +196 -0
- statement_extractor/cli.py +215 -0
- statement_extractor/extractor.py +649 -0
- statement_extractor/models.py +284 -0
- statement_extractor/predicate_comparer.py +611 -0
- statement_extractor/scoring.py +419 -0
|
@@ -0,0 +1,419 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Scoring module for statement extraction quality assessment.
|
|
3
|
+
|
|
4
|
+
Provides:
|
|
5
|
+
- TripleScorer: Score individual triples for groundedness
|
|
6
|
+
- BeamScorer: Score and select/merge beams based on quality metrics
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from .models import ScoringConfig, Statement
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TripleScorer:
    """
    Score individual triples for groundedness in source text.

    Groundedness is measured by checking:
    - Subject text appears in source
    - Object text appears in source
    - Subject and object are in proximity (same/nearby sentences)
    - Evidence span exists and is valid
    """

    def __init__(self, config: Optional[ScoringConfig] = None):
        # Fall back to default scoring configuration when none is supplied.
        self.config = config or ScoringConfig()

    def score_triple(self, statement: Statement, source_text: str) -> float:
        """
        Score a triple's groundedness (0-1).

        Higher scores indicate better grounding in source text.
        """
        if not source_text:
            # Nothing to ground against: report a neutral midpoint.
            return 0.5

        has_subject = self._text_appears_in(statement.subject.text, source_text)
        has_object = self._text_appears_in(statement.object.text, source_text)
        has_trigger = self._predicate_has_trigger(statement.predicate, source_text)

        # Weighted components: subject (0.3), object (0.3), predicate trigger (0.2).
        contributions = [
            0.3 * (1.0 if has_subject else 0.0),
            0.3 * (1.0 if has_object else 0.0),
            0.2 * (1.0 if has_trigger else 0.0),
        ]
        active_weights = [0.3, 0.3, 0.2]

        # Proximity (0.2) only participates when both endpoints were located.
        if has_subject and has_object:
            contributions.append(
                0.2 * self._compute_proximity(
                    statement.subject.text,
                    statement.object.text,
                    source_text,
                )
            )
            active_weights.append(0.2)

        denominator = sum(active_weights)
        return sum(contributions) / denominator if denominator > 0 else 0.0

    def find_evidence_span(
        self,
        statement: Statement,
        source_text: str
    ) -> Optional[tuple[int, int]]:
        """
        Find character offsets where the triple is grounded in source text.

        Returns (start, end) tuple or None if not found.
        """
        if not source_text:
            return None

        haystack = source_text.lower()

        # Prefer the statement's own source_text snippet when present.
        if statement.source_text:
            hit = haystack.find(statement.source_text.lower())
            if hit >= 0:
                return (hit, hit + len(statement.source_text))

        # Otherwise locate the smallest region spanning subject and object.
        subject_needle = statement.subject.text.lower()
        object_needle = statement.object.text.lower()
        subject_at = haystack.find(subject_needle)
        object_at = haystack.find(object_needle)

        if subject_at < 0 or object_at < 0:
            return None

        span_start = min(subject_at, object_at)
        span_end = max(
            subject_at + len(subject_needle),
            object_at + len(object_needle),
        )
        # Widen to the enclosing sentence boundaries.
        return self._extend_to_sentence(source_text, span_start, span_end)

    def _text_appears_in(self, text: str, source: str) -> bool:
        """Check if text appears in source (case-insensitive)."""
        return text.lower() in source.lower()

    def _predicate_has_trigger(self, predicate: str, source: str) -> bool:
        """Check if predicate has a lexical trigger in source."""
        haystack = source.lower()
        # Any predicate word longer than two characters counts as a trigger.
        return any(
            len(word) > 2 and word in haystack
            for word in predicate.lower().split()
        )

    def _compute_proximity(
        self,
        subject_text: str,
        object_text: str,
        source: str
    ) -> float:
        """
        Compute proximity score (0-1) based on distance between subject and object.

        Returns 1.0 if same sentence, decreasing with distance.
        """
        haystack = source.lower()
        first = haystack.find(subject_text.lower())
        second = haystack.find(object_text.lower())

        if first < 0 or second < 0:
            return 0.0

        lo, hi = sorted((first, second))
        between = source[lo:hi]

        # Count sentence terminators between the two mentions.
        boundary_count = sum(between.count(mark) for mark in '.!?')
        if boundary_count == 0:
            # Same sentence: maximal proximity.
            return 1.0

        # Each intervening sentence boundary knocks off 0.2, floored at 0.
        return max(0.0, 1.0 - boundary_count * 0.2)

    def _extend_to_sentence(
        self,
        source: str,
        start: int,
        end: int
    ) -> tuple[int, int]:
        """Extend span to sentence boundaries."""
        # Walk left until just after a sentence terminator (or start of text).
        left = start
        while left > 0 and source[left - 1] not in '.!?\n':
            left -= 1

        # Walk right until just past the next sentence terminator (or end of text).
        right = end
        while right < len(source):
            terminator = source[right] in '.!?\n'
            right += 1
            if terminator:
                break

        return (left, right)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class BeamScorer:
    """
    Score and select/merge beams based on quality metrics.

    Implements the scoring function:
        Score = Σ quality(t) + β×Coverage - γ×Redundancy - δ×Length
    """

    def __init__(
        self,
        config: Optional[ScoringConfig] = None,
        triple_scorer: Optional[TripleScorer] = None
    ):
        """
        Args:
            config: Scoring weights/thresholds; defaults to ScoringConfig().
            triple_scorer: Scorer for individual triples; when omitted, a
                TripleScorer built from ``config`` is used.
        """
        self.config = config or ScoringConfig()
        self.triple_scorer = triple_scorer or TripleScorer(config)

    def _ensure_scored(self, statements: list[Statement], source_text: str) -> None:
        """
        Fill in missing confidence scores and evidence spans in place.

        Uses explicit ``is None`` checks so a legitimate score of 0.0
        (which is falsy) is never recomputed or overwritten.
        """
        for stmt in statements:
            if stmt.confidence_score is None:
                stmt.confidence_score = self.triple_scorer.score_triple(stmt, source_text)
            if stmt.evidence_span is None:
                stmt.evidence_span = self.triple_scorer.find_evidence_span(stmt, source_text)

    def score_beam(
        self,
        statements: list[Statement],
        source_text: str
    ) -> float:
        """
        Compute beam score using the quality formula.

        Score = Σ quality(t) + β×Coverage - γ×Redundancy - δ×Length
        """
        if not statements:
            return 0.0

        # Sum of quality scores. NOTE: use an explicit None check rather
        # than ``or`` — a valid confidence of 0.0 is falsy, and ``or``
        # would silently discard it and re-score the triple.
        quality_sum = sum(
            stmt.confidence_score
            if stmt.confidence_score is not None
            else self.triple_scorer.score_triple(stmt, source_text)
            for stmt in statements
        )
        quality_term = self.config.quality_weight * quality_sum

        # Coverage bonus
        coverage = self.compute_coverage(statements, source_text)
        coverage_term = self.config.coverage_weight * coverage

        # Redundancy penalty
        redundancy = self.compute_redundancy(statements)
        redundancy_term = self.config.redundancy_penalty * redundancy

        # Length penalty (normalized by statement count; 10 statements ≈ 1.0)
        length = len(statements)
        length_term = self.config.length_penalty * (length / 10.0)

        return quality_term + coverage_term - redundancy_term - length_term

    def compute_coverage(
        self,
        statements: list[Statement],
        source_text: str
    ) -> float:
        """
        Compute coverage: % of source text tokens explained by evidence spans.

        Coverage is measured over non-whitespace characters only, so long
        runs of whitespace do not dilute the metric.
        """
        if not source_text or not statements:
            return 0.0

        # Track which character positions are covered by any evidence span.
        covered: set[int] = set()

        for stmt in statements:
            span = stmt.evidence_span
            if span is None:
                span = self.triple_scorer.find_evidence_span(stmt, source_text)

            if span:
                # Clamp to the text length in case a span overruns the source.
                for i in range(span[0], min(span[1], len(source_text))):
                    covered.add(i)

        content_chars = sum(1 for c in source_text if not c.isspace())
        covered_content = sum(1 for i in covered if not source_text[i].isspace())

        return covered_content / content_chars if content_chars > 0 else 0.0

    def compute_redundancy(self, statements: list[Statement]) -> float:
        """
        Compute redundancy penalty for near-duplicate triples.

        Only counts exact duplicates (same subject, predicate, and object,
        compared case-insensitively). Same subject+predicate with different
        objects is NOT redundant, as it represents distinct relationships
        (e.g., "Apple announced iPhone and iPad").

        Returns the fraction of statement pairs that are duplicates (0-1).
        """
        if len(statements) < 2:
            return 0.0

        redundant_pairs = 0
        total_pairs = 0

        for i, stmt1 in enumerate(statements):
            for stmt2 in statements[i + 1:]:
                total_pairs += 1

                # Only count exact duplicates (same subject, predicate, AND object)
                if (stmt1.subject.text.lower() == stmt2.subject.text.lower() and
                        stmt1.predicate.lower() == stmt2.predicate.lower() and
                        stmt1.object.text.lower() == stmt2.object.text.lower()):
                    redundant_pairs += 1

        return redundant_pairs / total_pairs if total_pairs > 0 else 0.0

    def score_and_rank_statements(
        self,
        statements: list[Statement],
        source_text: str
    ) -> list[Statement]:
        """
        Score each statement and return sorted by confidence (descending).

        Mutates the input statements: fills in any missing confidence
        scores and evidence spans before ranking.
        """
        self._ensure_scored(statements, source_text)
        return sorted(statements, key=lambda s: s.confidence_score or 0.0, reverse=True)

    def select_best_beam(
        self,
        candidates: list[list[Statement]],
        source_text: str
    ) -> list[Statement]:
        """
        Select the highest-scoring beam from candidates.

        Returns [] when there are no candidates. Ties are broken in favor
        of the earlier beam (matching a stable descending sort).
        """
        if not candidates:
            return []

        # Fill in per-statement scores/spans so score_beam sees them.
        for beam in candidates:
            self._ensure_scored(beam, source_text)

        # max() avoids sorting the entire list just to take its head.
        return max(candidates, key=lambda beam: self.score_beam(beam, source_text))

    def merge_beams(
        self,
        candidates: list[list[Statement]],
        source_text: str,
        top_n: Optional[int] = None
    ) -> list[Statement]:
        """
        Merge top-N beams, keeping high-quality unique triples.

        1. Score all beams
        2. Take top N beams
        3. Pool all triples
        4. Filter by confidence threshold
        5. Filter out predicate/evidence mismatches (hallucination guard)
        6. Deduplicate (keeping highest confidence)
        """
        if not candidates:
            return []

        top_n = top_n or self.config.merge_top_n

        # Score statements first so beam scores reflect real confidences.
        for beam in candidates:
            self._ensure_scored(beam, source_text)

        # Rank beams by score, best first (key on the score only, so the
        # beams themselves are never compared).
        scored_beams = sorted(
            ((self.score_beam(beam, source_text), beam) for beam in candidates),
            key=lambda pair: pair[0],
            reverse=True,
        )

        # Pool all triples from the top-N beams.
        all_statements = [
            stmt
            for _, beam in scored_beams[:top_n]
            for stmt in beam
        ]

        # Filter by confidence threshold.
        min_conf = self.config.min_confidence
        filtered = [s for s in all_statements if (s.confidence_score or 0) >= min_conf]

        # Filter out statements where source_text doesn't support the predicate.
        # This catches model hallucinations where predicate doesn't match the evidence.
        consistent = [s for s in filtered if self._source_text_supports_predicate(s)]

        # Deduplicate - keep highest confidence for each (subject, predicate, object).
        # Note: Same subject+predicate with different objects is valid
        # (e.g., "Apple announced X and Y").
        seen: dict[tuple[str, str, str], Statement] = {}
        for stmt in consistent:
            key = (
                stmt.subject.text.lower(),
                stmt.predicate.lower(),
                stmt.object.text.lower()
            )
            if key not in seen or (stmt.confidence_score or 0) > (seen[key].confidence_score or 0):
                seen[key] = stmt

        return list(seen.values())

    def _source_text_supports_predicate(self, stmt: Statement) -> bool:
        """
        Check if a statement's source_text contains a lexical trigger for its predicate.

        Returns True if:
        - source_text is None/empty (no requirement to check)
        - source_text contains at least one significant word (>2 chars)
          from the predicate

        Returns False if:
        - source_text is set but contains no words from the predicate
        """
        if not stmt.source_text:
            return True  # No source_text to check

        source_lower = stmt.source_text.lower()

        # Check if any significant predicate word appears in source_text.
        for word in stmt.predicate.lower().split():
            if len(word) > 2 and word in source_lower:
                return True

        return False
|