corp_extractor-0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,419 @@
+ """
+ Scoring module for statement extraction quality assessment.
+
+ Provides:
+ - TripleScorer: Score individual triples for groundedness
+ - BeamScorer: Score and select/merge beams based on quality metrics
+ """
+
+ from typing import Optional
+
+ from .models import ScoringConfig, Statement
+
+
+ class TripleScorer:
+     """
+     Score individual triples for groundedness in source text.
+
+     Groundedness is measured by checking:
+     - Subject text appears in source
+     - Object text appears in source
+     - Subject and object are in proximity (same/nearby sentences)
+     - Evidence span exists and is valid
+     """
+
+     def __init__(self, config: Optional[ScoringConfig] = None):
+         self.config = config or ScoringConfig()
+
+     def score_triple(self, statement: Statement, source_text: str) -> float:
+         """
+         Score a triple's groundedness (0-1).
+
+         Higher scores indicate better grounding in source text.
+         """
+         if not source_text:
+             return 0.5  # Neutral score if no source text
+
+         score = 0.0
+         weights_sum = 0.0
+
+         # Check subject appears in source (weight: 0.3)
+         subject_found = self._text_appears_in(statement.subject.text, source_text)
+         score += 0.3 * (1.0 if subject_found else 0.0)
+         weights_sum += 0.3
+
+         # Check object appears in source (weight: 0.3)
+         object_found = self._text_appears_in(statement.object.text, source_text)
+         score += 0.3 * (1.0 if object_found else 0.0)
+         weights_sum += 0.3
+
+         # Check predicate has lexical trigger (weight: 0.2)
+         predicate_grounded = self._predicate_has_trigger(statement.predicate, source_text)
+         score += 0.2 * (1.0 if predicate_grounded else 0.0)
+         weights_sum += 0.2
+
+         # Check proximity - subject and object in same/nearby region (weight: 0.2)
+         if subject_found and object_found:
+             proximity_score = self._compute_proximity(
+                 statement.subject.text,
+                 statement.object.text,
+                 source_text
+             )
+             score += 0.2 * proximity_score
+             weights_sum += 0.2
+
+         return score / weights_sum if weights_sum > 0 else 0.0
+
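The weighted average above has a subtle property worth spelling out: the proximity weight only enters the denominator when both endpoints were found, so a missing entity shrinks the scale rather than silently contributing zero. A small worked example of the arithmetic (the sentence and triple are illustrative, not taken from the package):

# ("Apple", "announced", "iPhone") against "Apple announced the iPhone in January 2007."
# All four checks pass, so every weight lands in both numerator and denominator:
score = (0.3 + 0.3 + 0.2 + 0.2) / (0.3 + 0.3 + 0.2 + 0.2)   # = 1.0
# If only the subject matched, the proximity branch is skipped entirely,
# so the denominator shrinks: 0.3 / (0.3 + 0.3 + 0.2) = 0.375.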
+     def find_evidence_span(
+         self,
+         statement: Statement,
+         source_text: str
+     ) -> Optional[tuple[int, int]]:
+         """
+         Find character offsets where the triple is grounded in source text.
+
+         Returns (start, end) tuple or None if not found.
+         """
+         if not source_text:
+             return None
+
+         # If statement has source_text field, try to find it
+         if statement.source_text:
+             pos = source_text.lower().find(statement.source_text.lower())
+             if pos >= 0:
+                 return (pos, pos + len(statement.source_text))
+
+         # Otherwise, find the region containing both subject and object
+         subject_lower = statement.subject.text.lower()
+         object_lower = statement.object.text.lower()
+         source_lower = source_text.lower()
+
+         subj_pos = source_lower.find(subject_lower)
+         obj_pos = source_lower.find(object_lower)
+
+         if subj_pos >= 0 and obj_pos >= 0:
+             start = min(subj_pos, obj_pos)
+             end = max(
+                 subj_pos + len(subject_lower),
+                 obj_pos + len(object_lower)
+             )
+             # Extend to sentence boundaries
+             start, end = self._extend_to_sentence(source_text, start, end)
+             return (start, end)
+
+         return None
+
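To illustrate the fallback path (a statement with no source_text of its own), here is a standalone sketch of the same string arithmetic on plain Python strings; it mirrors the logic above rather than calling the package, and the example sentence is made up:

source = "Apple announced the iPhone. Sales grew quickly."
low = source.lower()
subj, obj = "apple", "iphone"
start = min(low.find(subj), low.find(obj))                        # 0
end = max(low.find(subj) + len(subj), low.find(obj) + len(obj))   # 26
# Extending (start, end) to sentence boundaries yields (0, 27),
# i.e. source[0:27] == "Apple announced the iPhone."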
+     def _text_appears_in(self, text: str, source: str) -> bool:
+         """Check if text appears in source (case-insensitive)."""
+         return text.lower() in source.lower()
+
+     def _predicate_has_trigger(self, predicate: str, source: str) -> bool:
+         """Check if predicate has a lexical trigger in source."""
+         # Extract main verb/word from predicate
+         words = predicate.lower().split()
+         source_lower = source.lower()
+
+         # Check if any predicate word appears in source
+         for word in words:
+             if len(word) > 2 and word in source_lower:
+                 return True
+         return False
+
+     def _compute_proximity(
+         self,
+         subject_text: str,
+         object_text: str,
+         source: str
+     ) -> float:
+         """
+         Compute proximity score (0-1) based on distance between subject and object.
+
+         Returns 1.0 if same sentence, decreasing with distance.
+         """
+         source_lower = source.lower()
+         subj_pos = source_lower.find(subject_text.lower())
+         obj_pos = source_lower.find(object_text.lower())
+
+         if subj_pos < 0 or obj_pos < 0:
+             return 0.0
+
+         # Check if in same sentence
+         start = min(subj_pos, obj_pos)
+         end = max(subj_pos, obj_pos)
+         region = source[start:end]
+
+         # If no sentence boundary between them, high proximity
+         if '.' not in region and '!' not in region and '?' not in region:
+             return 1.0
+
+         # Otherwise, score decreases with the number of sentence
+         # boundaries separating the subject and object mentions
+         sentence_distance = region.count('.') + region.count('!') + region.count('?')
+         return max(0.0, 1.0 - (sentence_distance * 0.2))
+
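A standalone sketch of this heuristic on a plain string (not the package API; the sentences are made up). Each sentence boundary between the two mentions costs 0.2:

source = "Apple was founded in 1976. It announced the iPhone. The iPad followed."
low = source.lower()
subj_pos, obj_pos = low.find("apple"), low.find("ipad")
region = source[min(subj_pos, obj_pos):max(subj_pos, obj_pos)]
boundaries = region.count('.') + region.count('!') + region.count('?')
proximity = 1.0 if boundaries == 0 else max(0.0, 1.0 - boundaries * 0.2)
# Two boundaries separate "Apple" and "iPad", so proximity == 0.6.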
+     def _extend_to_sentence(
+         self,
+         source: str,
+         start: int,
+         end: int
+     ) -> tuple[int, int]:
+         """Extend span to sentence boundaries."""
+         # Find sentence start
+         sentence_start = start
+         while sentence_start > 0:
+             char = source[sentence_start - 1]
+             if char in '.!?\n':
+                 break
+             sentence_start -= 1
+
+         # Find sentence end
+         sentence_end = end
+         while sentence_end < len(source):
+             char = source[sentence_end]
+             if char in '.!?\n':
+                 sentence_end += 1
+                 break
+             sentence_end += 1
+
+         return (sentence_start, sentence_end)
+
+
+ class BeamScorer:
+     """
+     Score and select/merge beams based on quality metrics.
+
+     Implements the scoring function:
+     Score = α×Σ quality(t) + β×Coverage - γ×Redundancy - δ×Length
+     """
+
+     def __init__(
+         self,
+         config: Optional[ScoringConfig] = None,
+         triple_scorer: Optional[TripleScorer] = None
+     ):
+         self.config = config or ScoringConfig()
+         self.triple_scorer = triple_scorer or TripleScorer(config)
+
+     def score_beam(
+         self,
+         statements: list[Statement],
+         source_text: str
+     ) -> float:
+         """
+         Compute beam score using the quality formula.
+
+         Score = α×Σ quality(t) + β×Coverage - γ×Redundancy - δ×Length
+
+         where α, β, γ, δ come from ScoringConfig (quality_weight,
+         coverage_weight, redundancy_penalty, length_penalty).
+         """
+         if not statements:
+             return 0.0
+
+         # Sum of quality scores
+         quality_sum = sum(
+             (stmt.confidence_score or self.triple_scorer.score_triple(stmt, source_text))
+             for stmt in statements
+         )
+         quality_term = self.config.quality_weight * quality_sum
+
+         # Coverage bonus
+         coverage = self.compute_coverage(statements, source_text)
+         coverage_term = self.config.coverage_weight * coverage
+
+         # Redundancy penalty
+         redundancy = self.compute_redundancy(statements)
+         redundancy_term = self.config.redundancy_penalty * redundancy
+
+         # Length penalty (statement count, normalized by 10)
+         length = len(statements)
+         length_term = self.config.length_penalty * (length / 10.0)
+
+         return quality_term + coverage_term - redundancy_term - length_term
+
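An illustrative pass through the beam formula. The weights below are made up for the example; the real defaults live in ScoringConfig and may differ:

alpha, beta, gamma, delta = 1.0, 0.5, 0.5, 0.1   # illustrative weights only
quality_sum = 0.9 + 0.8 + 0.6                     # per-triple groundedness scores
coverage, redundancy, n_statements = 0.4, 0.0, 3
score = (alpha * quality_sum
         + beta * coverage
         - gamma * redundancy
         - delta * (n_statements / 10.0))
# 2.3 + 0.2 - 0.0 - 0.03 = 2.47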
+     def compute_coverage(
+         self,
+         statements: list[Statement],
+         source_text: str
+     ) -> float:
+         """
+         Compute coverage: fraction of non-whitespace source characters
+         covered by the statements' evidence spans.
+         """
+         if not source_text or not statements:
+             return 0.0
+
+         # Track which character positions are covered
+         covered = set()
+
+         for stmt in statements:
+             span = stmt.evidence_span
+             if span is None:
+                 span = self.triple_scorer.find_evidence_span(stmt, source_text)
+
+             if span:
+                 for i in range(span[0], min(span[1], len(source_text))):
+                     covered.add(i)
+
+         # Calculate coverage as percentage of non-whitespace characters
+         content_chars = sum(1 for c in source_text if not c.isspace())
+         covered_content = sum(1 for i in covered if not source_text[i].isspace())
+
+         return covered_content / content_chars if content_chars > 0 else 0.0
+
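A standalone sketch of the character-level coverage measure on a toy source and a single hand-picked evidence span (both illustrative):

source = "Apple announced the iPhone. Sales grew quickly."
spans = [(0, 27)]                      # evidence span covering the first sentence
covered = {i for start, end in spans for i in range(start, min(end, len(source)))}
content_chars = sum(1 for c in source if not c.isspace())
covered_content = sum(1 for i in covered if not source[i].isspace())
coverage = covered_content / content_chars if content_chars else 0.0
# 24 covered non-whitespace characters out of 41 -> coverage ≈ 0.59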
+     def compute_redundancy(self, statements: list[Statement]) -> float:
+         """
+         Compute redundancy penalty for near-duplicate triples.
+
+         Only counts exact duplicates (same subject, predicate, and object).
+         Note: Same subject+predicate with different objects is NOT redundant,
+         as it represents distinct relationships (e.g., "Apple announced iPhone and iPad").
+         """
+         if len(statements) < 2:
+             return 0.0
+
+         redundant_pairs = 0
+         total_pairs = 0
+
+         for i, stmt1 in enumerate(statements):
+             for stmt2 in statements[i + 1:]:
+                 total_pairs += 1
+
+                 # Only count exact duplicates (same subject, predicate, AND object)
+                 if (stmt1.subject.text.lower() == stmt2.subject.text.lower() and
+                         stmt1.predicate.lower() == stmt2.predicate.lower() and
+                         stmt1.object.text.lower() == stmt2.object.text.lower()):
+                     redundant_pairs += 1
+
+         return redundant_pairs / total_pairs if total_pairs > 0 else 0.0
+
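The same pair-counting on plain (subject, predicate, object) tuples, as a standalone sketch with made-up triples:

triples = [
    ("apple", "announced", "iphone"),
    ("apple", "announced", "ipad"),     # different object -> not redundant
    ("apple", "announced", "iphone"),   # exact duplicate  -> redundant
]
pairs = [(a, b) for i, a in enumerate(triples) for b in triples[i + 1:]]
redundancy = sum(a == b for a, b in pairs) / len(pairs)
# 1 duplicate pair out of 3 total pairs -> redundancy ≈ 0.33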
+     def score_and_rank_statements(
+         self,
+         statements: list[Statement],
+         source_text: str
+     ) -> list[Statement]:
+         """
+         Score each statement and return sorted by confidence (descending).
+         """
+         for stmt in statements:
+             if stmt.confidence_score is None:
+                 stmt.confidence_score = self.triple_scorer.score_triple(stmt, source_text)
+             if stmt.evidence_span is None:
+                 stmt.evidence_span = self.triple_scorer.find_evidence_span(stmt, source_text)
+
+         return sorted(statements, key=lambda s: s.confidence_score or 0.0, reverse=True)
+
+     def select_best_beam(
+         self,
+         candidates: list[list[Statement]],
+         source_text: str
+     ) -> list[Statement]:
+         """
+         Select the highest-scoring beam from candidates.
+         """
+         if not candidates:
+             return []
+
+         # Score each candidate and add confidence scores
+         scored_candidates = []
+         for beam in candidates:
+             # Score individual statements
+             for stmt in beam:
+                 if stmt.confidence_score is None:
+                     stmt.confidence_score = self.triple_scorer.score_triple(stmt, source_text)
+                 if stmt.evidence_span is None:
+                     stmt.evidence_span = self.triple_scorer.find_evidence_span(stmt, source_text)
+
+             beam_score = self.score_beam(beam, source_text)
+             scored_candidates.append((beam_score, beam))
+
+         # Select best
+         scored_candidates.sort(key=lambda x: x[0], reverse=True)
+         return scored_candidates[0][1]
+
+     def merge_beams(
+         self,
+         candidates: list[list[Statement]],
+         source_text: str,
+         top_n: Optional[int] = None
+     ) -> list[Statement]:
+         """
+         Merge top-N beams, keeping high-quality unique triples.
+
+         1. Score all beams
+         2. Take top N beams
+         3. Pool all triples
+         4. Filter by confidence threshold
+         5. Drop statements whose source_text lacks a lexical trigger for the predicate
+         6. Deduplicate, keeping the highest-confidence statement per
+            (subject, predicate, object)
+         """
+         if not candidates:
+             return []
+
+         top_n = top_n or self.config.merge_top_n
+
+         # Score each beam
+         scored_beams = []
+         for beam in candidates:
+             for stmt in beam:
+                 if stmt.confidence_score is None:
+                     stmt.confidence_score = self.triple_scorer.score_triple(stmt, source_text)
+                 if stmt.evidence_span is None:
+                     stmt.evidence_span = self.triple_scorer.find_evidence_span(stmt, source_text)
+
+             beam_score = self.score_beam(beam, source_text)
+             scored_beams.append((beam_score, beam))
+
+         # Sort and take top N
+         scored_beams.sort(key=lambda x: x[0], reverse=True)
+         top_beams = [beam for _, beam in scored_beams[:top_n]]
+
+         # Pool all triples
+         all_statements: list[Statement] = []
+         for beam in top_beams:
+             all_statements.extend(beam)
+
+         # Filter by confidence threshold
+         min_conf = self.config.min_confidence
+         filtered = [s for s in all_statements if (s.confidence_score or 0) >= min_conf]
+
+         # Filter out statements where source_text doesn't support the predicate
+         # This catches model hallucinations where the predicate doesn't match the evidence
+         consistent = [
+             s for s in filtered
+             if self._source_text_supports_predicate(s)
+         ]
+
+         # Deduplicate - keep highest confidence for each (subject, predicate, object)
+         # Note: Same subject+predicate with different objects is valid (e.g., "Apple announced X and Y")
+         seen: dict[tuple[str, str, str], Statement] = {}
+         for stmt in consistent:
+             key = (
+                 stmt.subject.text.lower(),
+                 stmt.predicate.lower(),
+                 stmt.object.text.lower()
+             )
+             if key not in seen or (stmt.confidence_score or 0) > (seen[key].confidence_score or 0):
+                 seen[key] = stmt
+
+         return list(seen.values())
+
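The deduplication step reduces to a dictionary keyed by the normalized triple. A standalone sketch with invented confidence values:

pooled = [
    (("apple", "announced", "iphone"), 0.91),
    (("apple", "announced", "iphone"), 0.74),   # duplicate, lower confidence
    (("apple", "announced", "ipad"), 0.80),
]
best: dict[tuple[str, str, str], float] = {}
for key, conf in pooled:
    if key not in best or conf > best[key]:
        best[key] = conf
# best keeps two entries: the 0.91 "iphone" triple and the 0.80 "ipad" triple.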
+     def _source_text_supports_predicate(self, stmt: Statement) -> bool:
+         """
+         Check if a statement's source_text contains a lexical trigger for its predicate.
+
+         Returns True if:
+         - source_text is None (no requirement to check)
+         - source_text contains at least one significant word from the predicate
+
+         Returns False if:
+         - source_text is set but contains no words from the predicate
+         """
+         if not stmt.source_text:
+             return True  # No source_text to check
+
+         predicate_words = stmt.predicate.lower().split()
+         source_lower = stmt.source_text.lower()
+
+         # Check if any significant predicate word appears in source_text
+         for word in predicate_words:
+             if len(word) > 2 and word in source_lower:
+                 return True
+
+         return False
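For orientation, a minimal end-to-end sketch of how the two classes compose. It assumes TripleScorer and BeamScorer from this module are in scope (the package's import path is not shown here), and it fakes Statement with types.SimpleNamespace objects exposing only the attributes read above; the real model and the ScoringConfig defaults live in .models and may behave differently.

from types import SimpleNamespace

def make_stmt(subj, pred, obj, source_text=None):
    # Stand-in for .models.Statement with only the attributes this module reads.
    return SimpleNamespace(
        subject=SimpleNamespace(text=subj),
        object=SimpleNamespace(text=obj),
        predicate=pred,
        source_text=source_text,
        confidence_score=None,
        evidence_span=None,
    )

source = "Apple announced the iPhone. Sales grew quickly."
beams = [
    [make_stmt("Apple", "announced", "iPhone")],
    [make_stmt("Apple", "announced", "iPhone"), make_stmt("Sales", "grew", "quickly")],
]
scorer = BeamScorer()  # ScoringConfig defaults (min_confidence etc.) come from .models
merged = scorer.merge_beams(beams, source, top_n=2)
for stmt in merged:
    print(stmt.subject.text, stmt.predicate, stmt.object.text, stmt.confidence_score)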