span-aligner 0.1.2-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- span_aligner/__init__.py +3 -3
- span_aligner/sentence_aligner.py +713 -0
- span_aligner/span_aligner.py +1178 -0
- span_aligner/span_projector.py +585 -0
- span_aligner-0.2.0.dist-info/METADATA +283 -0
- span_aligner-0.2.0.dist-info/RECORD +10 -0
- {span_aligner-0.1.2.dist-info → span_aligner-0.2.0.dist-info}/WHEEL +1 -1
- span_aligner-0.1.2.dist-info/METADATA +0 -169
- span_aligner-0.1.2.dist-info/RECORD +0 -7
- {span_aligner-0.1.2.dist-info → span_aligner-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {span_aligner-0.1.2.dist-info → span_aligner-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,585 @@
+import logging
+import re
+from math import log
+from typing import List, Dict, Any, Tuple, Optional
+from .sentence_aligner import EmbeddingProvider, SentenceAligner, TokenInfo, AlignmentResult, TransformerEmbeddingProvider
+from .span_aligner import SpanAligner
+
+logger = logging.getLogger(__name__)
+
+
+class SpanProjector:
+    """
+    Projector class to map annotations from tagged source text to projection target text.
+
+    Terminology:
+    - src / source: The text that HAS annotations/tags (e.g. the translation).
+    - tgt / target: The text to project annotations ONTO (e.g. the original).
+    """
+
+    _aligner_cache = {}
+    MATCHING_METHODS_MAP = {"a": "inter", "m": "mwmf", "i": "itermax", "f": "fwd", "r": "rev",
+                            "g": "greedy", "t": "threshold"}
+    MATCHING_METHODS_REV = {v: k for k, v in MATCHING_METHODS_MAP.items()}
+
+    def __init__(self,
+                 src_lang: str = "en",
+                 tgt_lang: str = "nl",
+                 matching_method: str = "mwmf",
+                 token_type: str = "bpe",
+                 embedding_provider: Optional[EmbeddingProvider] = None,
+                 model: str = "xlmr",
+                 device: Optional[str] = None,
+                 layer: int = 8,
+                 verbose: bool = False):
+        """
+        Initialize the projector.
+
+        Args:
+            src_lang: Language of the tagged source text
+            tgt_lang: Language of the projection target text
+            matching_method: Matching method (mwmf, itermax, inter, greedy, threshold)
+            token_type: Tokenization type ("bpe" or "word")
+            embedding_provider: Optional custom EmbeddingProvider instance
+            model: Model name for default TransformerEmbeddingProvider
+            device: Device for computation (None = auto-detect)
+            layer: Transformer layer for embeddings
+            verbose: Whether to enable verbose logging
+        """
+        self.src_lang = src_lang
+        self.tgt_lang = tgt_lang
+        self.matching_method = matching_method
+        self.token_type = token_type
+        self.verbose = verbose
+
+        # Auto-detect device
+        if device is None:
+            import torch
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
+
+        # Create or use provided embedding provider
+        if embedding_provider is not None:
+            self.embed_provider = embedding_provider
+        else:
+            self.embed_provider = TransformerEmbeddingProvider(model=model, device=device, layer=layer)
+
+        # Initialize aligner - this is the single source for tokenization and alignment
+        self.aligner = self._get_cached_aligner()
+
+    def _get_cached_aligner(self) -> SentenceAligner:
+        """Get or create cached SentenceAligner instance."""
+        cache_key = f"aligner_{self.matching_method}_{self.token_type}_{id(self.embed_provider)}"
+        if cache_key not in self._aligner_cache:
+            match_key = self.MATCHING_METHODS_REV.get(self.matching_method, "m")
+            self._aligner_cache[cache_key] = SentenceAligner(
+                embedding_provider=self.embed_provider,
+                token_type=self.token_type,
+                matching_methods=match_key
+            )
+        return self._aligner_cache[cache_key]
+
+    # -------------------------------------------------------------------------
+    # Helper methods
+    # -------------------------------------------------------------------------
+
+    def _entropy(self, distribution: List[float]) -> float:
+        """Calculate entropy of a probability distribution."""
+        total = sum(distribution)
+        if total <= 0:
+            return 0.0
+        normalized = [v / total for v in distribution if v > 0]
+        return -sum(v * log(v, 2) for v in normalized)
+
+    def _find_contiguous_clusters(self, indices: List[int], max_gap: int = 1) -> List[List[int]]:
+        """Split token indices into contiguous groups, allowing small gaps."""
+        if not indices:
+            return []
+
+        indices = sorted(set(indices))
+        clusters = []
+        current = [indices[0]]
+
+        for idx in indices[1:]:
+            if idx - current[-1] - 1 <= max_gap:
+                current.append(idx)
+            else:
+                clusters.append(current)
+                current = [idx]
+        clusters.append(current)
+        return clusters
+
+    def _fill_gaps_in_clusters(self, clusters: List[List[int]]) -> List[List[int]]:
+        """Make each cluster a fully continuous, sorted span."""
+        filled_clusters = []
+
+        for cluster in clusters:
+            if not cluster:
+                continue
+
+            cluster = sorted(set(cluster))
+            start, end = cluster[0], cluster[-1]
+            filled_clusters.append(list(range(start, end + 1)))
+
+        return filled_clusters
+
+    def _cluster_merge(self, clusters: List[List[int]], src_token_count: Optional[int] = None, base_gap: int = 2, mass_factor: float = 0.5) -> List[int]:
+        """
+        Merges multiple alignment clusters into the single most likely contiguous sequence.
+
+        Args:
+            clusters: List of index lists, e.g., [[1,2], [5,6,7], [20,21]].
+            src_token_count: Length of source span (used for sanity checking).
+            base_gap: Minimum gap always allowed (handles punctuation/particles).
+            mass_factor: Multiplier for dynamic tolerance.
+                Allowed gap = base_gap + int(min(len(A), len(B)) * mass_factor * 1.2**min(len(A), len(B))).
+
+        Returns:
+            A single flat list of indices representing the best alignment path.
+        """
+        if not clusters:
+            return []
+
+        # 1. Sort clusters by start index to enable linear chaining
+        sorted_clusters = sorted(clusters, key=lambda x: x[0])
+
+        # 2. Build Chains
+        # A 'chain' is a list of clusters that have successfully merged
+        chains: List[List[List[int]]] = []
+        current_chain = [sorted_clusters[0]]
+
+        for next_cluster in sorted_clusters[1:]:
+            prev_cluster = current_chain[-1]
+
+            # Calculate actual gap between the end of previous and start of next
+            # e.g., prev=[2], next=[5] -> gap is indices 3,4 -> size 2.
+            gap_size = next_cluster[0] - prev_cluster[-1] - 1
+
+            # Calculate Dynamic Tolerance
+            # We use the minimum mass because a weak link (size 1) breaks the chain easily.
+            connection_mass = min(len(prev_cluster), len(next_cluster))
+            allowed_gap = base_gap + int(connection_mass * mass_factor * 1.2**connection_mass)
+            if self.verbose:
+                print(f"Considering merge: Prev={prev_cluster}, Next={next_cluster}, Gap={gap_size}, Allowed={allowed_gap}")
+
+            if gap_size <= allowed_gap:
+                # Valid merge: Extend the current chain
+                current_chain.append(next_cluster)
+            else:
+                # Gap too large: Finalize current chain and start a new one
+                chains.append(current_chain)
+                current_chain = [next_cluster]
+
+        # Don't forget to save the final chain being built
+        chains.append(current_chain)
+
+        # 3. Select the Best Chain
+        best_chain_indices = []
+        best_score = (-1, -1)  # (token_count, -span_length)
+
+        for chain in chains:
+            # Flatten the chain into a single list of indices
+            flat_indices = [idx for cluster in chain for idx in cluster]
+
+            if not flat_indices:
+                continue
+
+            token_count = len(flat_indices)
+            span_length = flat_indices[-1] - flat_indices[0] + 1
+
+            # 4. Outlier Sanity Check (if source length is known)
+            # If the target span is massively larger than source (e.g. > 4x),
+            # it's likely a bad "spread" alignment. We penalize it heavily.
+            if src_token_count and span_length > (src_token_count * 4) + 6:
+                token_count = 0  # Disqualify this chain effectively
+
+            # 5. Scoring
+            # Priority 1: Maximize number of aligned tokens
+            # Priority 2: Minimize total span length (prefer compactness)
+            # We use negative span_length because we want to MAXIMIZE the tuple
+            score = (token_count, -span_length)
+
+            if score > best_score:
+                best_score = score
+                best_chain_indices = flat_indices
+
+        return best_chain_indices
+
+
+
+
+    def _map_span_to_target_tokens(self,
+                                   span: Dict[str, Any],
+                                   src_tokens: List[TokenInfo],
+                                   tgt_tokens: List[TokenInfo],
+                                   alignment: List[Tuple[int, int]],
+                                   max_gap: int = 1) -> Optional[List[int]]:
+        """Map a span from source tokens (tagged) to target token indices (projected)."""
+        s_start, s_end = span.get("start"), span.get("end")
+
+        # Build source->target mapping
+        src_to_tgt = {}
+        for src_idx, tgt_idx in alignment:
+            src_to_tgt.setdefault(src_idx, []).append(tgt_idx)
+
+        # Find source tokens covered by this span
+        covered_src = [i for i, t in enumerate(src_tokens)
+                       if max(s_start, t.start) < min(s_end, t.end)]
+
+        if not covered_src:
+            return None
+
+        # Get corresponding target tokens
+        mapped_tgt = []
+        for s_idx in covered_src:
+            mapped_tgt.extend(src_to_tgt.get(s_idx, []))
+
+        if not mapped_tgt:
+            return None
+
+        # Cluster and select largest
+        clusters = self._find_contiguous_clusters(mapped_tgt, max_gap=max_gap)
+        return max(clusters, key=len)
+
+    def validate_projection(self, text: str, spans: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """Validate projected spans for boundary integrity and text matching."""
+        issues = []
+        sorted_spans = sorted(spans, key=lambda s: (s["start"], s["end"]))
+
+        for i, span in enumerate(sorted_spans):
+            start, end = span["start"], span["end"]
+
+            if not (0 <= start <= end <= len(text)):
+                issues.append(f"Span out of bounds: {span}")
+                continue
+
+            extracted = text[start:end]
+            if extracted.strip() != span.get("text", "").strip():
+                issues.append(f"Text mismatch at {i}: expected '{span.get('text')}', got '{extracted}'")
+
+        # Check overlaps
+        for i in range(len(sorted_spans) - 1):
+            curr, next_s = sorted_spans[i], sorted_spans[i + 1]
+            if curr["end"] > next_s["start"] and curr["end"] < next_s["end"]:
+                if curr.get("labels") != next_s.get("labels"):
+                    issues.append(f"Partial overlap: {curr['text']} and {next_s['text']}")
+
+        return {"valid": len(issues) == 0, "issues": issues}
+
+    # -------------------------------------------------------------------------
+    # Main projection methods
+    # -------------------------------------------------------------------------
+
+    def project_spans(self,
+                      src_text: str,
+                      tgt_text: str,
+                      src_spans: List[Dict[str, Any]],
+                      max_gap: int = 1,
+                      debugging: bool = False) -> List[Dict[str, Any]]:
+        """
+        Project spans from source text (tagged) to target text.
+
+        Args:
+            src_text: Tagged source text
+            tgt_text: Target text for projection
+            src_spans: List of span dicts on source text
+            max_gap: Maximum gap in token clustering
+
+        Returns:
+            List of projected span dicts on target text
+        """
+        # Use project_spans_with_debug_info but filter to standard output format
+        detailed = self.project_spans_with_debug_info(src_text, tgt_text, src_spans, max_gap)
+        if debugging:
+            print(self._visualize_projection(detailed))
+
+        projected = []
+        for item in detailed:
+            projected.append({
+                "start": item["start"],
+                "end": item["end"],
+                "text": item["text"],
+                "labels": item["labels"]
+            })
+
+        return projected
+
+    def project_spans_with_debug_info(self,
+                                      src_text: str,
+                                      tgt_text: str,
+                                      src_spans: List[Dict[str, Any]],
+                                      max_gap: int) -> List[Dict[str, Any]]:
+
+        """Project spans with detailed debug information."""
+        if not src_text or not tgt_text or not src_spans:
+            return []
+
+        # Pre-compute embeddings for the full texts
+        if self.verbose:
+            print("Pre-computing embeddings for full texts...")
+        src_tokens, src_vectors = self.aligner.get_text_embeddings(src_text)
+        tgt_tokens, tgt_vectors = self.aligner.get_text_embeddings(tgt_text)
+
+        if self.verbose:
+            print("\n===============================")
+            print(f"Source tokens: {[t.text for t in src_tokens]}")
+            print(f"Target tokens: {[t.text for t in tgt_tokens]}")
+            print(f"Lengths - Source: {len(src_tokens)} tokens, Target: {len(tgt_tokens)} tokens")
+            print(f"Lengths - Vectors: Source {src_vectors.shape}, Target {tgt_vectors.shape}")
+
+
+        projected = []
+        for span in src_spans:
+            if self.verbose:
+                print("\n=================================")
+                print("1.1 Span text:", span.get('text', ''))
+                print("1.2 Word count:", len(span.get('text', '').split()))
+
+            # Use align_texts_partial_with_embeddings with char positions from the span
+            try:
+                result = self.aligner.align_texts_partial_with_embeddings(
+                    src_tokens,
+                    tgt_tokens,
+                    src_vectors,
+                    tgt_vectors,
+                    src_char_start=span['start'],
+                    src_char_end=span['end']
+                )
+            except ValueError as e:
+                if self.verbose:
+                    print(f"Skipping span alignment due to error (likely no token coverage): {e}")
+                continue
+
+            alignment = result.alignments.get(self.matching_method, [])
+
+            src_to_tgt = {}
+            for src_idx, tgt_idx in alignment:
+                src_to_tgt.setdefault(src_idx, []).append(tgt_idx)
+
+            # Source tokens covered by current span
+            covered_src = [i for i, t in enumerate(result.src_tokens)
+                           if max(span['start'], t.start) < min(span['end'], t.end)]
+
+            mapped_tgt = []
+            for s_idx in covered_src:
+                mapped_tgt.extend(src_to_tgt.get(s_idx, []))
+
+            if not mapped_tgt:
+                continue
+
+
+            clusters = self._find_contiguous_clusters(mapped_tgt, max_gap)
+            clusters = self._fill_gaps_in_clusters(clusters)
+            largest = self._cluster_merge(clusters, src_token_count=len(covered_src), base_gap=max_gap, mass_factor=0.6)
+
+            #largest = max(clusters, key=len)
+
+            min_idx, max_idx = min(largest), max(largest)
+            tgt_start = result.tgt_tokens[min_idx].start
+            tgt_end = result.tgt_tokens[max_idx].end
+
+
+
+            if self.verbose:
+                print("1.3 Performing clustering.")
+                print(f"1.4 No-gap clusters: {self._find_contiguous_clusters(mapped_tgt, max_gap=0)}")
+                print(f"1.5 Computed alignment clusters: {clusters}")
+                print(f"1.6 Merged clusters into largest: {largest}")
+                print(f"1.7 Covered source tokens: {[result.src_tokens[i].text for i in covered_src]}")
+                print(f"1.8 Mapped target tokens: {[result.tgt_tokens[i].text for i in mapped_tgt]}")
+                print(f"1.9 Projected span: [{min_idx}->{tgt_start}, {max_idx}->{tgt_end})")
+                print(f"1.10 Sim matrix shape: {result.similarity_matrix.shape}")
+
+            # Build token alignment details
+            token_alignments = []
+            for src_idx in covered_src:
+                best_tgt, best_score = None, 0
+                for tgt_idx in largest:
+                    if src_idx < result.similarity_matrix.shape[0] and tgt_idx < result.similarity_matrix.shape[1]:
+                        score = result.similarity_matrix[src_idx, tgt_idx]
+                        if score > best_score:
+                            best_score, best_tgt = score, tgt_idx
+
+                # Check cross-similarity for display
+                best_tgt_any = None
+                best_score_any = 0
+                for ti in range(len(result.tgt_tokens)):
+                    if src_idx < result.similarity_matrix.shape[0] and ti < result.similarity_matrix.shape[1]:
+                        if result.similarity_matrix[src_idx, ti] > best_score_any:
+                            best_score_any = result.similarity_matrix[src_idx, ti]
+                            best_tgt_any = ti
+
+                token_alignments.append({
+                    'src_idx': src_idx,
+                    'src_token': result.src_tokens[src_idx].text if src_idx < len(result.src_tokens) else '?',
+                    'tgt_idx': best_tgt if best_tgt is not None else best_tgt_any,
+                    'tgt_token': result.tgt_tokens[best_tgt].text if best_tgt is not None else (result.tgt_tokens[best_tgt_any].text if best_tgt_any is not None else '?'),
+                    'score': float(best_score if best_tgt is not None else best_score_any),
+                    'in_cluster': best_tgt is not None
+                })
+
+            projected.append({
+                "start": tgt_start,
+                "end": tgt_end,
+                "text": tgt_text[tgt_start:tgt_end],
+                "labels": span.get("labels", []),
+                "cluster_size": len(largest),
+                "total_aligned": len(mapped_tgt),  # Number of target tokens aligned to the covered source tokens
+                "num_clusters": len(clusters),
+                "token_alignments": token_alignments,
+                "all_clusters": clusters
+            })
+
+        return projected
+
+
+    def project_tagged_text(self,
+                            src_text: str,
+                            tgt_text: str,
+                            allowed_tags: Optional[List[str]] = None,
+                            max_gap: int = 1,
+                            debugging: bool = False) -> Tuple[str, List[Dict[str, Any]]]:
+        """
+        Project annotations from tagged source text to target text.
+
+        Args:
+            src_text: Tagged source text
+            tgt_text: Target text (raw)
+
+        Returns:
+            Tuple of (Tagged Target Text, Projected Spans)
+        """
+        parsed = SpanAligner.tagged_text_to_task(src_text, allowed_tags=allowed_tags)
+        raw_src_text = parsed["task"]["data"]["text"]
+        src_spans = parsed["spans"] + parsed["entities"]
+
+        #print(f"Projecting following text:\nSRC: {raw_src_text}\nTGT: {tgt_text}\nWith spans: {src_spans}")
+
+        projected = self.project_spans(raw_src_text, tgt_text, src_spans, max_gap=max_gap, debugging=debugging)
+        tagged_tgt, _ = SpanAligner.rebuild_tagged_text(tgt_text, projected, [])
+
+        return tagged_tgt, projected
+
+    def project_tagged_text_long(self,
+                                 src_text: str,
+                                 tgt_text: str,
+                                 allowed_tags: Optional[List[str]] = None,
+                                 window_size: int = 512,
+                                 max_gap: int = 1,
+                                 debugging: bool = False) -> Tuple[str, List[Dict[str, Any]]]:
+        """Project annotations for long documents using token window alignment."""
+        if allowed_tags is None:
+            allowed_tags = list(set(re.findall(r"<([a-zA-Z_][a-zA-Z0-9_-]*)>", src_text)))
+
+        parsed = SpanAligner.tagged_text_to_task(src_text, allowed_tags=allowed_tags)
+        raw_src_text = parsed["task"]["data"]["text"]
+        all_spans = parsed["spans"] + parsed["entities"]
+
+        if not all_spans:
+            return tgt_text, []
+
+        import bisect
+
+        # 1. Tokenize both the src and target to get character offsets
+        def get_tokens(text: str) -> List[Tuple[int, int]]:
+            """Return list of (start, end) character offsets for tokens."""
+            # Simple regex tokenizer that captures words and punctuation
+            return [m.span() for m in re.finditer(r'\w+|[^\w\s]', text)]
+
+        src_tokens = get_tokens(raw_src_text)
+        tgt_tokens = get_tokens(tgt_text)
+
+        if not src_tokens or not tgt_tokens:
+            return tgt_text, []
+
+        src_starts = [t[0] for t in src_tokens]
+        projected = []
+
+        # 2. Iterate over the spans of the src
+        half_window = window_size // 2
+
+        for span in all_spans:
+            # Get relative position in text (center of span)
+            span_center = (span["start"] + span["end"]) / 2
+
+            # Find closest token index in src
+            center_idx = bisect.bisect_left(src_starts, span_center)
+            # bisect_left gives the index where it could be inserted.
+            # If it's an exact match or after, we might want index-1 or index.
+            # Just constrain it to bounds.
+            center_idx = max(0, min(len(src_tokens) - 1, center_idx))
+
+            # Create window of tokens around the span
+            src_start_idx = max(0, center_idx - half_window)
+            src_end_idx = min(len(src_tokens), center_idx + half_window)
+
+            # Get character ranges for source window
+            win_src_start_char = src_tokens[src_start_idx][0]
+            win_src_end_char = src_tokens[src_end_idx - 1][1]
+            src_window_text = raw_src_text[win_src_start_char:win_src_end_char]
+
+            # 3. Get the corresponding window in the tgt tokens
+            # Map relative position: (src_idx / src_len) -> tgt_idx
+            if len(src_tokens) > 1:
+                rel_pos = center_idx / (len(src_tokens) - 1)
+            else:
+                rel_pos = 0.5
+
+            tgt_center_idx = int(rel_pos * len(tgt_tokens))
+            tgt_start_idx = max(0, tgt_center_idx - half_window)
+            tgt_end_idx = min(len(tgt_tokens), tgt_center_idx + half_window)
+
+            win_tgt_start_char = tgt_tokens[tgt_start_idx][0]
+            win_tgt_end_char = tgt_tokens[tgt_end_idx - 1][1]
+            tgt_window_text = tgt_text[win_tgt_start_char:win_tgt_end_char]
+
+            # 4. Project the span within the source window
+            # We need to adjust the span to be relative to the source window
+            rel_span = span.copy()
+            rel_span["start"] -= win_src_start_char
+            rel_span["end"] -= win_src_start_char
+
+            # Sanity check: span must be within extracted text
+            # (It should be, unless window is smaller than the span itself)
+            if rel_span["start"] < 0 or rel_span["end"] > len(src_window_text):
+                continue
+
+            # Project just this span within the window
+            # We pass a list containing just the single span
+            window_projections = self.project_spans(
+                src_window_text,
+                tgt_window_text,
+                [rel_span],
+                max_gap=max_gap,
+                debugging=debugging
+            )
+
+            # 5. Map the positions of the projected spans back onto the full text
+            for p in window_projections:
+                p["start"] += win_tgt_start_char
+                p["end"] += win_tgt_start_char
+                projected.append(p)
+
+        tagged_tgt, _ = SpanAligner.rebuild_tagged_text(tgt_text, projected, [])
+        return tagged_tgt, projected
+
+    def _visualize_projection(self, projected_spans: List[Dict[str, Any]]) -> str:
+        """Generate visualization of projected spans with alignment info."""
+        lines = ["=" * 80, "PROJECTION VISUALIZATION", "=" * 80]
+
+        for i, span in enumerate(projected_spans):
+            lines.append(f"\nSPAN {i}: {span.get('labels', [])}")
+            lines.append("-" * 80)
+            lines.append(f"Projected: '{span['text']}' [{span['start']}, {span['end']})")
+
+            if 'num_clusters' in span:
+                lines.append(f"Clusters: {span['num_clusters']} (selected: {span['cluster_size']} tokens)")
+
+            if 'token_alignments' in span:
+                lines.append("Token alignments:")
+                for a in span['token_alignments']:
+                    marker = "✅" if a['score'] >= 0.95 else "✓" if a['score'] >= 0.85 else "⚠️" if a['score'] >= 0.7 else "❌"
+                    tgt_part = f"-> '{a['tgt_token']}'" if a.get('in_cluster') else f"(best match: '{a['tgt_token']}')"
+                    lines.append(f"  {marker} '{a['src_token']}' {tgt_part} ({a['score']:.3f})")
+
+        return "\n".join(lines)
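For orientation, a minimal usage sketch of the SpanProjector API added in this release. The import path follows the module layout above; the sample texts, the <name> tag, and the language pair are illustrative only (not taken from the package), and the default configuration will load the underlying "xlmr" transformer model on first use:

    from span_aligner.span_projector import SpanProjector

    # The English source carries the tags; the Dutch target receives the projection.
    projector = SpanProjector(src_lang="en", tgt_lang="nl", matching_method="mwmf")
    tagged_nl, spans = projector.project_tagged_text(
        "The <name>European Commission</name> approved the plan.",  # illustrative tagged source
        "De Europese Commissie keurde het plan goed.",               # illustrative raw target
        allowed_tags=["name"],
    )
    # tagged_nl: the target text with tags re-inserted around the projected spans
    # spans: list of dicts with "start", "end", "text", and "labels" relative to the target text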