span-aligner 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- span_aligner/__init__.py +3 -3
- span_aligner/sentence_aligner.py +713 -0
- span_aligner/span_aligner.py +1178 -0
- span_aligner/span_projector.py +585 -0
- span_aligner-0.2.0.dist-info/METADATA +283 -0
- span_aligner-0.2.0.dist-info/RECORD +10 -0
- {span_aligner-0.1.2.dist-info → span_aligner-0.2.0.dist-info}/WHEEL +1 -1
- span_aligner-0.1.2.dist-info/METADATA +0 -169
- span_aligner-0.1.2.dist-info/RECORD +0 -7
- {span_aligner-0.1.2.dist-info → span_aligner-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {span_aligner-0.1.2.dist-info → span_aligner-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,713 @@
|
|
|
1
|
+
# coding=utf-8
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Dict, List, Tuple, Union, Optional
|
|
7
|
+
import re
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from scipy.sparse import csr_matrix
|
|
12
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
|
13
|
+
try:
|
|
14
|
+
import networkx as nx
|
|
15
|
+
from networkx.algorithms.bipartite.matrix import from_biadjacency_matrix
|
|
16
|
+
except ImportError:
|
|
17
|
+
nx = None
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# =============================================================================
|
|
22
|
+
# DATA STRUCTURES
|
|
23
|
+
# =============================================================================
|
|
24
|
+
|
|
25
|
+
@dataclass
class TokenInfo:
    """A single token plus its character span in the original text.

    Produced by ``EmbeddingProvider.tokenize_text``. Offsets follow Python
    slice conventions, i.e. ``text[start:end]`` recovers the token.
    """
    # Token surface form.
    text: str
    # Character offset of the first character (inclusive).
    start: int
    # Character offset one past the last character (exclusive).
    end: int
    # 0-based position of this token within the token sequence.
    idx: int
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class AlignmentResult:
    """Complete alignment result with all necessary data for projection."""
    # method name -> [(src_idx, tgt_idx), ...] sorted alignment pairs.
    alignments: Dict[str, List[Tuple[int, int]]]
    # Source-side tokens (possibly a partial slice of the full source text).
    src_tokens: List[TokenInfo]
    # Target-side tokens for the whole target text.
    tgt_tokens: List[TokenInfo]
    # Pairwise similarity between src and tgt embedding rows.
    similarity_matrix: np.ndarray
    # Embeddings aligned row-for-row with src_tokens / tgt_tokens.
    src_vectors: np.ndarray
    tgt_vectors: np.ndarray
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# =============================================================================
|
|
46
|
+
# EMBEDDING PROVIDERS - Pluggable embedding backends
|
|
47
|
+
# =============================================================================
|
|
48
|
+
|
|
49
|
+
class EmbeddingProvider(ABC):
    """Abstract interface for pluggable token-embedding backends."""

    # Canonical word tokenization: runs of word characters, or any single
    # non-whitespace punctuation character.
    WORD_PATTERN = re.compile(r'\b\w+\b|[^\s\w]')

    def tokenize_text(self, text: str) -> List[TokenInfo]:
        """Split *text* into ``TokenInfo`` items carrying character offsets.

        This is the canonical tokenization used throughout the pipeline.
        """
        return [
            TokenInfo(text=m.group(), start=m.start(), end=m.end(), idx=pos)
            for pos, m in enumerate(self.WORD_PATTERN.finditer(text))
        ]

    @abstractmethod
    def get_embeddings(self, tokens: List[str]) -> np.ndarray:
        """Return an embedding matrix for the given token strings."""

    @abstractmethod
    def get_subword_to_word_map(self, words: List[str]) -> Tuple[List[str], List[int]]:
        """Return ``(subwords, word_map)`` where ``word_map[i]`` is the index
        of the word that subword ``i`` belongs to."""
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class TransformerEmbeddingProvider(EmbeddingProvider):
    """Embedding provider using HuggingFace transformers.

    Models and tokenizers are loaded once per (model, device) pair and shared
    via a class-level cache; contextual subword embeddings are read from a
    fixed hidden layer.
    """

    # Class-level cache shared by all instances:
    # "transformer_{model}_{device}" -> (model, tokenizer).
    _model_cache = {}

    # Short aliases for commonly used multilingual checkpoints; any other
    # string is passed through to HuggingFace unchanged.
    MODEL_CONFIGS = {
        'bert': 'bert-base-multilingual-cased',
        'xlmr': 'xlm-roberta-base',
        'xlmr-large': 'xlm-roberta-large',
    }

    def __init__(self, model: str = "bert", device: str = "cpu", layer: int = 8):
        """Load (or reuse a cached) model/tokenizer pair.

        Args:
            model: Alias from MODEL_CONFIGS or any HuggingFace model name.
            device: torch device string (e.g. "cpu", "cuda:0").
            layer: Index into the model's hidden_states used as embeddings.
        """
        # Imported lazily so the package works without transformers installed
        # when a different provider is used.
        from transformers import AutoConfig, AutoModel, AutoTokenizer

        self.device = torch.device(device)
        self.layer = layer

        model_name = self.MODEL_CONFIGS.get(model, model)
        self.model_name = model_name

        cache_key = f"transformer_{model_name}_{device}"
        if cache_key not in self._model_cache:
            # output_hidden_states=True so the forward pass exposes all layers.
            config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
            emb_model = AutoModel.from_pretrained(model_name, config=config)
            emb_model.eval()
            emb_model.to(self.device)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            self._model_cache[cache_key] = (emb_model, tokenizer)

        self.emb_model, self.tokenizer = self._model_cache[cache_key]

    def get_subword_to_word_map(self, words: List[str]) -> Tuple[List[str], List[int]]:
        """Tokenize each word independently; ``word_map[i]`` is the index of
        the word that produced subword ``i``."""
        subwords = []
        word_map = []
        for i, word in enumerate(words):
            tokens = self.tokenizer.tokenize(word)
            subwords.extend(tokens)
            word_map.extend([i] * len(tokens))
        return subwords, word_map

    def get_embeddings(self, tokens: List[str]) -> np.ndarray:
        """Return hidden-state vectors for one pre-split token list.

        NOTE(review): the result keeps the leading batch dimension — shape
        (1, seq, dim) — while other providers return 2-D (n, dim) arrays.
        Callers in this module appear to use get_embeddings_batch for this
        provider instead; confirm before relying on this method directly.
        """
        with torch.no_grad():
            inputs = self.tokenizer(tokens, is_split_into_words=True,
                                    padding=True, truncation=True, return_tensors="pt")
            hidden = self.emb_model(**inputs.to(self.device))["hidden_states"]
            if self.layer >= len(hidden):
                raise ValueError(f"Layer {self.layer} requested but model has only {len(hidden)} layers.")
            # Drop the special tokens at both sequence ends ([CLS]/[SEP]-style).
            outputs = hidden[self.layer][:, 1:-1, :]
            return outputs.cpu().numpy()

    def get_embeddings_batch(self, batch: List[List[str]]) -> List[np.ndarray]:
        """Get embeddings for a batch of token lists."""
        with torch.no_grad():
            inputs = self.tokenizer(batch, is_split_into_words=True,
                                    padding=True, truncation=True, return_tensors="pt")
            hidden = self.emb_model(**inputs.to(self.device))["hidden_states"]
            # Drop the special tokens at both ends; remaining positions include
            # padding, trimmed per-sequence below.
            outputs = hidden[self.layer][:, 1:-1, :]

            results = []
            for i, tokens in enumerate(batch):
                # Re-tokenize to recover this sequence's true (unpadded) subword
                # length. NOTE(review): if truncation fired above, seq_len can
                # exceed the available positions — verify inputs stay short.
                subwords = []
                for word in tokens:
                    subwords.extend(self.tokenizer.tokenize(word))
                seq_len = len(subwords)
                results.append(outputs[i, :seq_len].cpu().numpy())
            return results
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class OllamaEmbeddingProvider(EmbeddingProvider):
    """Embedding provider using Ollama API (new /api/embed endpoint)."""

    def __init__(
        self,
        model: str = "nomic-embed-text",
        base_url: str = "http://localhost:11434",
    ):
        """Bind the HTTP client, model name, and endpoint base URL."""
        try:
            import requests
        except ImportError:
            raise ImportError("requests package required for Ollama provider")

        self._requests = requests
        self.model = model
        # Normalize so URL joins below never produce a double slash.
        self.base_url = base_url.rstrip("/")

    def get_subword_to_word_map(self, words: List[str]) -> Tuple[List[str], List[int]]:
        """Identity mapping: Ollama embeds whole words, no subword splitting."""
        identity = list(range(len(words)))
        return words, identity

    def get_embeddings(self, tokens: List[str]) -> np.ndarray:
        """POST *tokens* to ``/api/embed`` and return a float32 matrix."""
        if not tokens:
            return np.empty((0, 0))

        payload = {
            "model": self.model,
            "input": tokens,
        }
        response = self._requests.post(
            f"{self.base_url}/api/embed",
            json=payload,
            timeout=60,
        )
        response.raise_for_status()

        # Ollama returns: { "embeddings": [ [...], [...], ... ] }
        body = response.json()
        return np.array(body["embeddings"], dtype=np.float32)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class SentenceTransformerProvider(EmbeddingProvider):
    """Embedding provider using sentence-transformers."""

    # Class-level cache: one loaded SentenceTransformer per (model, device).
    _model_cache = {}

    def __init__(self, model: str = "paraphrase-multilingual-MiniLM-L12-v2", device: str = "cpu"):
        """Load (or reuse a cached) SentenceTransformer.

        Args:
            model: sentence-transformers model name.
            device: Device string passed through to SentenceTransformer.

        Raises:
            ImportError: if sentence-transformers is not installed.
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError("sentence-transformers package required")

        self.device = device
        # FIX: include the device in the cache key. Previously the key was
        # f"sbert_{model}", so two providers with the same model but different
        # devices shared the first-loaded model pinned to the wrong device.
        # TransformerEmbeddingProvider already keys its cache on device.
        cache_key = f"sbert_{model}_{device}"
        if cache_key not in self._model_cache:
            self._model_cache[cache_key] = SentenceTransformer(model, device=device)
        self.model = self._model_cache[cache_key]

    def get_subword_to_word_map(self, words: List[str]) -> Tuple[List[str], List[int]]:
        """Identity mapping — sentence-transformers encodes whole strings."""
        return words, list(range(len(words)))

    def get_embeddings(self, tokens: List[str]) -> np.ndarray:
        """Encode each token string to a vector; returns (n_tokens, dim)."""
        return self.model.encode(tokens, convert_to_numpy=True)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# =============================================================================
|
|
216
|
+
# SENTENCE ALIGNER - Core alignment logic
|
|
217
|
+
# =============================================================================
|
|
218
|
+
|
|
219
|
+
class SentenceAligner:
    """Word alignment using contextual embeddings and various matching algorithms."""

    # One-letter selection codes -> canonical algorithm names:
    #   a=inter(section), m=mwmf (max-weight matching), i=itermax,
    #   f=forward argmax, r=reverse argmax, g=greedy, t=threshold.
    MATCHING_METHODS = {"a": "inter", "m": "mwmf", "i": "itermax", "f": "fwd", "r": "rev",
                        "g": "greedy", "t": "threshold"}

    def __init__(self,
                 embedding_provider: Optional[EmbeddingProvider] = None,
                 model: str = "bert",
                 token_type: str = "bpe",
                 distortion: float = 0.0,
                 matching_methods: str = "mai",
                 device: str = "cpu",
                 layer: int = 8):
        """
        Initialize SentenceAligner.

        Args:
            embedding_provider: Optional pre-configured EmbeddingProvider instance.
            model: Model name (used if embedding_provider is None)
            token_type: "bpe" for subword alignment, "word" for word-level
            distortion: Position distortion factor (0.0 = no distortion)
            matching_methods: String of method codes (see MATCHING_METHODS)
            device: Device for computation
            layer: Transformer layer to extract embeddings from
        """
        self.token_type = token_type
        self.distortion = distortion
        # Unrecognized method codes are skipped silently.
        self.matching_methods = [self.MATCHING_METHODS[m] for m in matching_methods if m in self.MATCHING_METHODS]

        if embedding_provider is not None:
            self.embed_provider = embedding_provider
        else:
            self.embed_provider = TransformerEmbeddingProvider(model=model, device=device, layer=layer)

    # -------------------------------------------------------------------------
    # Static alignment algorithms
    # -------------------------------------------------------------------------

    @staticmethod
    def get_similarity(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        """Pairwise cosine similarity rescaled from [-1, 1] into [0, 1]."""
        return (cosine_similarity(X, Y) + 1.0) / 2.0

    @staticmethod
    def apply_distortion(sim_matrix: np.ndarray, ratio: float = 0.5) -> np.ndarray:
        """Down-weight positionally distant pairs in the similarity matrix.

        A ratio of 0.0, or a degenerate matrix with fewer than 2 rows or
        columns, returns the input unchanged.
        """
        shape = sim_matrix.shape
        if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
            return sim_matrix
        # Relative positions in [0, 1]; the squared positional gap scaled by
        # ratio is subtracted from 1 to build a multiplicative penalty mask.
        pos_x = np.array([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])])
        pos_y = np.array([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])])
        distortion_mask = 1.0 - ((pos_x - np.transpose(pos_y)) ** 2) * ratio
        return np.multiply(sim_matrix, distortion_mask)

    @staticmethod
    def get_alignment_matrix(sim_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return one-hot forward (row argmax) and reverse (column argmax) matrices."""
        m, n = sim_matrix.shape
        forward = np.eye(n)[sim_matrix.argmax(axis=1)]
        backward = np.eye(m)[sim_matrix.argmax(axis=0)]
        return forward, backward.transpose()

    @staticmethod
    def get_max_weight_match(sim: np.ndarray) -> np.ndarray:
        """Maximum-weight full bipartite matching over the similarity matrix.

        Raises:
            ValueError: if networkx is not installed.
        """
        if nx is None:
            raise ValueError("networkx must be installed to use mwmf algorithm.")
        def permute(edge):
            # from_biadjacency_matrix numbers target nodes after all source
            # nodes; normalize each matched edge back to (src_idx, tgt_idx).
            if edge[0] < sim.shape[0]:
                return edge[0], edge[1] - sim.shape[0]
            else:
                return edge[1], edge[0] - sim.shape[0]
        G = from_biadjacency_matrix(csr_matrix(sim))
        matching = nx.max_weight_matching(G, maxcardinality=True)
        matching = [permute(x) for x in matching]
        res_matrix = np.zeros_like(sim)
        for edge in matching:
            res_matrix[edge[0], edge[1]] = 1
        return res_matrix

    @staticmethod
    def iter_max(sim_matrix: np.ndarray, max_count: int = 2) -> np.ndarray:
        """Iterative mutual-argmax matching (itermax).

        Starts from the forward/backward argmax intersection and repeatedly
        adds mutual argmax pairs among still-unaligned rows/columns, for up
        to *max_count* rounds.
        """
        alpha_ratio = 0.9
        m, n = sim_matrix.shape
        forward = np.eye(n)[sim_matrix.argmax(axis=1)]
        backward = np.eye(m)[sim_matrix.argmax(axis=0)]
        inter = forward * backward.transpose()

        # For tiny matrices the plain intersection is already the answer.
        if min(m, n) <= 2:
            return inter

        count = 1
        while count < max_count:
            # mask_x / mask_y select rows / columns with no alignment yet.
            mask_x = 1.0 - np.tile(inter.sum(1)[:, np.newaxis], (1, n)).clip(0.0, 1.0)
            mask_y = 1.0 - np.tile(inter.sum(0)[np.newaxis, :], (m, 1)).clip(0.0, 1.0)
            mask = ((alpha_ratio * mask_x) + (alpha_ratio * mask_y)).clip(0.0, 1.0)
            mask_zeros = 1.0 - ((1.0 - mask_x) * (1.0 - mask_y))
            # If either side is fully aligned, suppress further additions.
            if mask_x.sum() < 1.0 or mask_y.sum() < 1.0:
                mask *= 0.0
                mask_zeros *= 0.0

            new_sim = sim_matrix * mask
            fwd = np.eye(n)[new_sim.argmax(axis=1)] * mask_zeros
            bac = np.eye(m)[new_sim.argmax(axis=0)].transpose() * mask_zeros
            new_inter = fwd * bac

            # Converged: no new pairs were found this round.
            if np.array_equal(inter + new_inter, inter):
                break
            inter = inter + new_inter
            count += 1
        return inter

    @staticmethod
    def greedy_match(sim_matrix: np.ndarray, one_to_one: bool = True) -> np.ndarray:
        """Greedy matching by descending similarity; optionally one-to-one."""
        m, n = sim_matrix.shape
        result = np.zeros((m, n))
        sim_flat = sim_matrix.flatten()
        indices = np.argsort(-sim_flat)

        used_src = set() if one_to_one else None
        used_tgt = set() if one_to_one else None

        for idx in indices:
            i, j = idx // n, idx % n
            # Scores are visited in descending order, so the first
            # non-positive score ends the scan.
            if sim_matrix[i, j] <= 0:
                break
            if one_to_one:
                if i in used_src or j in used_tgt:
                    continue
                used_src.add(i)
                used_tgt.add(j)
            result[i, j] = 1
        return result

    @staticmethod
    def threshold_match(sim_matrix: np.ndarray, threshold: float = 0.5) -> np.ndarray:
        """Align every pair whose similarity is at least *threshold*."""
        return (sim_matrix >= threshold).astype(float)

    # -------------------------------------------------------------------------
    # Main alignment methods
    # -------------------------------------------------------------------------

    def _prepare_tokens(self, sent: Union[str, List[str]]) -> Tuple[List[str], List[str], List[int]]:
        """Prepare tokens and get subword mapping."""
        if isinstance(sent, str):
            words = sent.split()
        else:
            words = sent
        subwords, word_map = self.embed_provider.get_subword_to_word_map(words)
        return words, subwords, word_map

    def _compute_alignments(self, sim: np.ndarray) -> Dict[str, np.ndarray]:
        """Compute all requested alignment matrices.

        fwd/rev/inter are always computed (inter is their element-wise
        product); the other methods only when requested at construction.
        """
        all_mats = {}
        all_mats["fwd"], all_mats["rev"] = self.get_alignment_matrix(sim)
        all_mats["inter"] = all_mats["fwd"] * all_mats["rev"]

        if "mwmf" in self.matching_methods:
            all_mats["mwmf"] = self.get_max_weight_match(sim)
        if "itermax" in self.matching_methods:
            all_mats["itermax"] = self.iter_max(sim)
        if "greedy" in self.matching_methods:
            all_mats["greedy"] = self.greedy_match(sim)
        if "threshold" in self.matching_methods:
            all_mats["threshold"] = self.threshold_match(sim)

        return all_mats

    def _average_embeds_over_words(self,
                                   bpe_vectors: List[np.ndarray],
                                   words: List[List[str]],
                                   subwords: List[List[str]]) -> List[np.ndarray]:
        """Average subword embeddings to get word-level embeddings.

        Args:
            bpe_vectors: [src_subword_vectors, tgt_subword_vectors].
            words: [src_words, tgt_words].
            subwords: Kept for API compatibility; the subword-to-word map is
                recomputed from the provider rather than taken from here.

        Returns:
            Two arrays of word-level vectors (zero vector for any word whose
            subwords fall outside the available embedding rows).
        """
        new_vectors = []
        for lang_idx in range(2):
            word_list = words[lang_idx]
            _, word_map = self.embed_provider.get_subword_to_word_map(word_list)

            word_vectors = []
            for word_idx in range(len(word_list)):
                subword_indices = [i for i, w in enumerate(word_map) if w == word_idx]
                # Guard against truncation: only average rows that exist.
                if subword_indices and max(subword_indices) < bpe_vectors[lang_idx].shape[0]:
                    word_vectors.append(bpe_vectors[lang_idx][subword_indices].mean(0))
                else:
                    word_vectors.append(np.zeros(bpe_vectors[lang_idx].shape[1]))

            new_vectors.append(np.array(word_vectors) if word_vectors else np.zeros((0, bpe_vectors[lang_idx].shape[1])))
        return new_vectors

    def get_text_embeddings(self, text: str) -> Tuple[List[TokenInfo], np.ndarray]:
        """
        Get embeddings for a text string.

        Args:
            text: Input text

        Returns:
            Tuple of (tokens, embeddings)
        """
        tokens = self.embed_provider.tokenize_text(text)
        words = [t.text for t in tokens]

        # Get subword mappings
        subwords, word_map = self.embed_provider.get_subword_to_word_map(words)

        # Prefer the batch API when the provider has one (it trims padding).
        if hasattr(self.embed_provider, 'get_embeddings_batch'):
            vectors = self.embed_provider.get_embeddings_batch([words])[0]
        else:
            vectors = self.embed_provider.get_embeddings(words)

        # Truncate to subword lengths
        vectors = vectors[:len(subwords)]

        # Average over words if needed
        if self.token_type == "word":
            vectors = self._average_embeds_over_words(
                [vectors, vectors],  # Hack: pass twice to satisfy list expectation
                [words, words],
                [subwords, subwords]
            )[0]

        return tokens, vectors

    def _get_embeddings_and_similarity(self, src_words: List[str], tgt_words: List[str],
                                       compute_sim: bool = True) -> Tuple[List[np.ndarray], Optional[np.ndarray]]:
        """Get embeddings for source and target words, optionally compute similarity."""
        # Get subword mappings
        src_subwords, _ = self.embed_provider.get_subword_to_word_map(src_words)
        tgt_subwords, _ = self.embed_provider.get_subword_to_word_map(tgt_words)

        # Get embeddings (batch API preferred when available)
        if hasattr(self.embed_provider, 'get_embeddings_batch'):
            vectors = self.embed_provider.get_embeddings_batch([src_words, tgt_words])
        else:
            vectors = [
                self.embed_provider.get_embeddings(src_words),
                self.embed_provider.get_embeddings(tgt_words)
            ]

        # Truncate to subword lengths
        vectors = [vectors[0][:len(src_subwords)], vectors[1][:len(tgt_subwords)]]

        # Average over words if needed
        if self.token_type == "word":
            vectors = self._average_embeds_over_words(
                vectors, [src_words, tgt_words], [src_subwords, tgt_subwords]
            )

        # Compute similarity if requested
        sim = None
        if compute_sim:
            sim = self.get_similarity(vectors[0], vectors[1])
            sim = self.apply_distortion(sim, self.distortion)

        return vectors, sim

    def _get_subword_range(self, word_start: int, word_end: int,
                           w2b: List[int]) -> Tuple[int, int]:
        """Convert a word range to a subword range (identity in word mode)."""
        if self.token_type == "bpe":
            sub_start = next((i for i, w in enumerate(w2b) if w >= word_start), 0)
            sub_end = next((i for i, w in enumerate(w2b) if w >= word_end), len(w2b))
            return sub_start, sub_end
        return word_start, word_end

    def _get_alignments_from_similarity(self, sim: np.ndarray, vectors: List[np.ndarray],
                                        src_w2b: List[int], tgt_w2b: List[int],
                                        n_src_sub: int, n_tgt_sub: int,
                                        n_src_word: int, n_tgt_word: int,
                                        tgt_offset: int = 0) -> Dict[str, List]:
        """
        Convert similarity matrix to alignment tuples.

        Args:
            sim: Similarity matrix
            vectors: Source and target embedding vectors
            src_w2b: Source subword-to-word mapping
            tgt_w2b: Target subword-to-word mapping
            n_src_sub: Number of source subwords
            n_tgt_sub: Number of target subwords
            n_src_word: Number of source words
            n_tgt_word: Number of target words
            tgt_offset: Offset for target indices (for partial alignment)

        Returns:
            Dictionary mapping method names to sorted alignment lists
        """
        all_mats = self._compute_alignments(sim)
        aligns = {method: set() for method in self.matching_methods}

        # Matrix rows/columns are subwords in bpe mode, words in word mode.
        n_src = n_src_sub if self.token_type == "bpe" else n_src_word
        n_tgt = n_tgt_sub if self.token_type == "bpe" else n_tgt_word

        for i in range(min(vectors[0].shape[0], n_src)):
            for j in range(min(vectors[1].shape[0], n_tgt)):
                actual_tgt_idx = j + tgt_offset

                for method in self.matching_methods:
                    if method in all_mats and all_mats[method][i, j] > 0:
                        # In bpe mode collapse subword pairs to word pairs via
                        # the subword-to-word maps; a set dedupes the results.
                        if self.token_type == "bpe":
                            src_idx = src_w2b[i] if i < len(src_w2b) else i
                            tgt_idx = tgt_w2b[actual_tgt_idx] if actual_tgt_idx < len(tgt_w2b) else actual_tgt_idx
                        else:
                            src_idx = i
                            tgt_idx = actual_tgt_idx

                        aligns[method].add((src_idx, tgt_idx))

        # Convert sets to sorted lists
        return {method: sorted(align_set) for method, align_set in aligns.items()}

    def get_word_aligns(self,
                        src_sent: Union[str, List[str]],
                        trg_sent: Union[str, List[str]]) -> Tuple[Dict[str, List], List[np.ndarray], np.ndarray]:
        """
        Get word alignments between source and target sentences.
        Legacy API - for backwards compatibility.
        """
        src_words, src_subwords, src_w2b = self._prepare_tokens(src_sent)
        tgt_words, tgt_subwords, tgt_w2b = self._prepare_tokens(trg_sent)

        vectors, sim = self._get_embeddings_and_similarity(src_words, tgt_words)

        aligns = self._get_alignments_from_similarity(sim, vectors, src_w2b, tgt_w2b,
                                                      len(src_subwords), len(tgt_subwords),
                                                      len(src_words), len(tgt_words))

        return aligns, vectors, sim

    def align_texts(self, src_text: str, tgt_text: str) -> AlignmentResult:
        """
        Align two texts and return complete alignment result.

        This is the main entry point that handles tokenization, embedding, and alignment.
        All tokenization is done by the embedding provider for consistency.

        Args:
            src_text: Source text string
            tgt_text: Target text string

        Returns:
            AlignmentResult with alignments, tokens, and similarity matrix
        """
        # FIX: progress messages used to go through print(); route them
        # through the library logger so applications control verbosity.
        log = logging.getLogger(__name__)
        log.debug("1.1 Computing alignments...")
        result = self.align_texts_partial(src_text, tgt_text, src_char_start=0, src_char_end=None)
        log.debug("1.2 Alignments computed.")
        return result

    def align_texts_partial(self, src_text: str, tgt_text: str,
                            src_char_start: int = 0,
                            src_char_end: Optional[int] = None) -> AlignmentResult:
        """
        Align two texts with partial source range defined by character positions.

        Args:
            src_text: Source text string
            tgt_text: Target text string
            src_char_start: Start index (char position) in source text
            src_char_end: End index (char position) in source text (None = end of text)

        Returns:
            AlignmentResult with alignments for the partial source range.
        """
        # compute embeddings
        src_tokens, src_vectors = self.get_text_embeddings(src_text)
        tgt_tokens, tgt_vectors = self.get_text_embeddings(tgt_text)

        return self.align_texts_partial_with_embeddings(
            src_tokens, tgt_tokens, src_vectors, tgt_vectors, src_char_start, src_char_end
        )

    def align_texts_partial_with_embeddings(self,
                                            src_tokens: List[TokenInfo],
                                            tgt_tokens: List[TokenInfo],
                                            src_vectors: np.ndarray,
                                            tgt_vectors: np.ndarray,
                                            src_char_start: int,
                                            src_char_end: Optional[int] = None) -> AlignmentResult:
        """
        Align partial source text to target text using pre-computed embeddings and character positions.
        """
        src_words = [t.text for t in src_tokens]
        tgt_words = [t.text for t in tgt_tokens]

        # Default end: end of the last source token (0 for empty input).
        if src_char_end is None:
            if src_tokens:
                src_char_end = src_tokens[-1].end
            else:
                src_char_end = 0

        # Map the character range onto a token index range.
        src_start_idx = None
        src_end_idx = None

        for i, token in enumerate(src_tokens):
            t_start = token.start
            t_end = token.end

            # Start index: first token that ends after char_start
            if src_start_idx is None and t_end > src_char_start:
                src_start_idx = i

            # End index: last token that starts before char_end
            if t_start < src_char_end:
                src_end_idx = i + 1

        # Fall back to the full token range if the char range matched nothing.
        if src_start_idx is None:
            src_start_idx = 0
        if src_end_idx is None:
            src_end_idx = len(src_tokens)

        src_subwords, src_w2b = self.embed_provider.get_subword_to_word_map(src_words)
        tgt_subwords, tgt_w2b = self.embed_provider.get_subword_to_word_map(tgt_words)

        # Determine source subword range
        src_sub_start, src_sub_end = self._get_subword_range(src_start_idx, src_end_idx, src_w2b)

        # Slice source vectors and rebase the word map to the partial range.
        src_partial_vecs = src_vectors[src_sub_start:src_sub_end]
        src_w2b_partial = [idx - src_start_idx for idx in src_w2b[src_sub_start:src_sub_end]]

        # Compute similarity
        sim_partial = self.get_similarity(src_partial_vecs, tgt_vectors)
        sim_partial = self.apply_distortion(sim_partial, self.distortion)

        # Compute alignments
        aligns = self._get_alignments_from_similarity(
            sim_partial, [src_partial_vecs, tgt_vectors], src_w2b_partial, tgt_w2b,
            len(src_w2b_partial), len(tgt_subwords),
            src_end_idx - src_start_idx, len(tgt_words),
            tgt_offset=0
        )

        partial_src_tokens = src_tokens[src_start_idx:src_end_idx]

        return AlignmentResult(
            alignments=aligns,
            src_tokens=partial_src_tokens,
            tgt_tokens=tgt_tokens,
            similarity_matrix=sim_partial,
            src_vectors=src_partial_vecs,
            tgt_vectors=tgt_vectors
        )

    def align_texts_partial_substring(self, src_text: str, tgt_text: str,
                                      src_substring: str) -> AlignmentResult:
        """
        Align two texts using a substring to define the source range.

        This method finds the substring in the source text and aligns only that portion.

        Args:
            src_text: Source text string
            tgt_text: Target text string
            src_substring: Substring to find in source text

        Returns:
            AlignmentResult with alignments for the substring range

        Raises:
            ValueError: If substring is not found in source text
        """
        # Find substring position (first occurrence only).
        substring_start = src_text.find(src_substring)
        if substring_start == -1:
            raise ValueError(f"Substring '{src_substring}' not found in source text")

        substring_end = substring_start + len(src_substring)

        # Use align_texts_partial with computed indices
        return self.align_texts_partial(src_text, tgt_text, substring_start, substring_end)

    def print_alignment(self, alignment_result: AlignmentResult, method: str = "inter") -> str:
        """Print alignments in a human-readable format and return the same text.

        Shows each aligned index pair with the corresponding source and
        target token texts for the given method.

        Args:
            alignment_result: Result to render.
            method: Alignment method key to display.

        Returns:
            The rendered text (an error message if *method* is unknown).
        """
        if method not in alignment_result.alignments:
            return f"Method '{method}' not found in alignment results."

        src_tokens = alignment_result.src_tokens
        tgt_tokens = alignment_result.tgt_tokens
        aligns = alignment_result.alignments[method]

        # FIX: the success path previously returned None despite the
        # ``-> str`` annotation; collect the printed lines and return them.
        lines = []
        for alignment in aligns:
            line = f" {alignment}: {src_tokens[alignment[0]].text} <-> {tgt_tokens[alignment[1]].text}"
            print(line)
            lines.append(line)
        return "\n".join(lines)