span-aligner 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,713 @@
1
+ # coding=utf-8
2
+
3
+ import os
4
+ import logging
5
+ from abc import ABC, abstractmethod
6
+ from typing import Dict, List, Tuple, Union, Optional
7
+ import re
8
+ from dataclasses import dataclass
9
+
10
+ import numpy as np
11
+ from scipy.sparse import csr_matrix
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+ try:
14
+ import networkx as nx
15
+ from networkx.algorithms.bipartite.matrix import from_biadjacency_matrix
16
+ except ImportError:
17
+ nx = None
18
+ import torch
19
+
20
+
21
+ # =============================================================================
22
+ # DATA STRUCTURES
23
+ # =============================================================================
24
+
25
@dataclass
class TokenInfo:
    """Token with text and character offsets."""
    text: str   # surface form of the token
    start: int  # character offset of the first character in the original text
    end: int    # character offset one past the last character (Python slice convention)
    idx: int    # position of this token within the token sequence
32
+
33
+
34
@dataclass
class AlignmentResult:
    """Complete alignment result with all necessary data for projection."""
    # method name (e.g. "inter", "mwmf", "itermax") -> [(src_idx, tgt_idx), ...]
    alignments: Dict[str, List[Tuple[int, int]]]
    # source/target tokens with character offsets (see TokenInfo)
    src_tokens: List[TokenInfo]
    tgt_tokens: List[TokenInfo]
    # pairwise source/target similarity used to produce the alignments
    similarity_matrix: np.ndarray
    # embedding vectors the similarity matrix was computed from
    src_vectors: np.ndarray
    tgt_vectors: np.ndarray
43
+
44
+
45
+ # =============================================================================
46
+ # EMBEDDING PROVIDERS - Pluggable embedding backends
47
+ # =============================================================================
48
+
49
class EmbeddingProvider(ABC):
    """Abstract base class for embedding providers."""

    # Default word tokenization pattern: runs of word characters, or a single
    # non-space, non-word character (punctuation).
    WORD_PATTERN = re.compile(r'\b\w+\b|[^\s\w]')

    def tokenize_text(self, text: str) -> List[TokenInfo]:
        """Tokenize text into tokens with character offsets.

        This is the canonical tokenization used throughout the pipeline.
        """
        return [
            TokenInfo(text=m.group(), start=m.start(), end=m.end(), idx=pos)
            for pos, m in enumerate(self.WORD_PATTERN.finditer(text))
        ]

    @abstractmethod
    def get_embeddings(self, tokens: List[str]) -> np.ndarray:
        """Get embeddings for a list of tokens."""

    @abstractmethod
    def get_subword_to_word_map(self, words: List[str]) -> Tuple[List[str], List[int]]:
        """Get subword tokens and mapping to original word indices."""
79
+
80
+
81
class TransformerEmbeddingProvider(EmbeddingProvider):
    """Embedding provider using HuggingFace transformers.

    Models and tokenizers are cached at class level keyed by (model, device)
    so constructing several providers does not reload weights.
    """

    # Shared across instances; mutated (never rebound) through ``self``.
    _model_cache = {}

    # Short aliases for common multilingual checkpoints.
    MODEL_CONFIGS = {
        'bert': 'bert-base-multilingual-cased',
        'xlmr': 'xlm-roberta-base',
        'xlmr-large': 'xlm-roberta-large',
    }

    def __init__(self, model: str = "bert", device: str = "cpu", layer: int = 8):
        """
        Args:
            model: Alias from MODEL_CONFIGS or a raw HuggingFace model name.
            device: torch device string (e.g. "cpu", "cuda").
            layer: Index into the model's hidden-state stack to read embeddings from.
        """
        from transformers import AutoConfig, AutoModel, AutoTokenizer

        self.device = torch.device(device)
        self.layer = layer

        model_name = self.MODEL_CONFIGS.get(model, model)
        self.model_name = model_name

        cache_key = f"transformer_{model_name}_{device}"
        if cache_key not in self._model_cache:
            # output_hidden_states=True is required so forward() returns the
            # full per-layer stack read in _encode().
            config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
            emb_model = AutoModel.from_pretrained(model_name, config=config)
            emb_model.eval()
            emb_model.to(self.device)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            self._model_cache[cache_key] = (emb_model, tokenizer)

        self.emb_model, self.tokenizer = self._model_cache[cache_key]

    def get_subword_to_word_map(self, words: List[str]) -> Tuple[List[str], List[int]]:
        """Return (subword tokens, parallel list of source-word indices)."""
        subwords = []
        word_map = []
        for i, word in enumerate(words):
            tokens = self.tokenizer.tokenize(word)
            subwords.extend(tokens)
            word_map.extend([i] * len(tokens))
        return subwords, word_map

    def _encode(self, batch) -> torch.Tensor:
        """Run the model and return hidden states of the configured layer.

        Returns a (batch, seq, dim) tensor with the special [CLS]/[SEP]
        positions stripped. Raises ValueError when ``self.layer`` is out of
        range (previously only the single-sentence path validated this).
        """
        with torch.no_grad():
            inputs = self.tokenizer(batch, is_split_into_words=True,
                                    padding=True, truncation=True, return_tensors="pt")
            hidden = self.emb_model(**inputs.to(self.device))["hidden_states"]
        if self.layer >= len(hidden):
            raise ValueError(f"Layer {self.layer} requested but model has only {len(hidden)} layers.")
        # Drop the first and last positions (special tokens).
        return hidden[self.layer][:, 1:-1, :]

    def get_embeddings(self, tokens: List[str]) -> np.ndarray:
        """Get embeddings for a single pre-tokenized sentence.

        Returns a (n_subwords, dim) array. Fix: the previous implementation
        returned the raw (1, seq, dim) tensor with its batch axis, which
        broke callers that slice ``vectors[:len(subwords)]`` and was
        inconsistent with the 2-D contract of the other providers.
        """
        outputs = self._encode(tokens)
        return outputs[0].cpu().numpy()

    def get_embeddings_batch(self, batch: List[List[str]]) -> List[np.ndarray]:
        """Get embeddings for a batch of token lists.

        Each result is truncated to the sentence's own subword length so
        padding positions are not returned.
        """
        outputs = self._encode(batch)
        results = []
        for i, tokens in enumerate(batch):
            subwords = []
            for word in tokens:
                subwords.extend(self.tokenizer.tokenize(word))
            results.append(outputs[i, :len(subwords)].cpu().numpy())
        return results
147
+
148
+
149
class OllamaEmbeddingProvider(EmbeddingProvider):
    """Embedding provider using Ollama API (new /api/embed endpoint)."""

    def __init__(
        self,
        model: str = "nomic-embed-text",
        base_url: str = "http://localhost:11434",
    ):
        """
        Args:
            model: Ollama embedding model name.
            base_url: Base URL of the Ollama server (trailing slash tolerated).
        """
        try:
            import requests
        except ImportError:
            raise ImportError("requests package required for Ollama provider")

        self.model = model
        self.base_url = base_url.rstrip("/")
        self._requests = requests

    def get_subword_to_word_map(self, words: List[str]) -> Tuple[List[str], List[int]]:
        """Identity mapping: Ollama embeds whole words, no subword split."""
        return words, list(range(len(words)))

    def get_embeddings(self, tokens: List[str]) -> np.ndarray:
        """Embed each token via the /api/embed endpoint.

        Returns a (n_tokens, dim) float32 array; an empty (0, 0) array for
        empty input.
        """
        if not tokens:
            return np.empty((0, 0))

        payload = {
            "model": self.model,
            "input": tokens,
        }
        resp = self._requests.post(
            f"{self.base_url}/api/embed",
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()

        # Ollama returns: { "embeddings": [ [...], [...], ... ] }
        return np.array(resp.json()["embeddings"], dtype=np.float32)
188
+
189
+
190
+
191
class SentenceTransformerProvider(EmbeddingProvider):
    """Embedding provider using sentence-transformers."""

    # Class-level cache so identical models are loaded once per process.
    _model_cache = {}

    def __init__(self, model: str = "paraphrase-multilingual-MiniLM-L12-v2", device: str = "cpu"):
        """
        Args:
            model: sentence-transformers model name.
            device: Device string passed through to SentenceTransformer.
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError("sentence-transformers package required")

        self.device = device
        cache_key = f"sbert_{model}"
        cached = self._model_cache.get(cache_key)
        if cached is None:
            cached = SentenceTransformer(model, device=device)
            self._model_cache[cache_key] = cached
        self.model = cached

    def get_subword_to_word_map(self, words: List[str]) -> Tuple[List[str], List[int]]:
        """Identity mapping: each word is embedded as-is, no subword split."""
        return words, list(range(len(words)))

    def get_embeddings(self, tokens: List[str]) -> np.ndarray:
        """Encode tokens to a (n_tokens, dim) numpy array."""
        return self.model.encode(tokens, convert_to_numpy=True)
213
+
214
+
215
+ # =============================================================================
216
+ # SENTENCE ALIGNER - Core alignment logic
217
+ # =============================================================================
218
+
219
class SentenceAligner:
    """Word alignment using contextual embeddings and various matching algorithms."""

    # One-letter codes accepted in ``matching_methods`` -> canonical method name.
    MATCHING_METHODS = {"a": "inter", "m": "mwmf", "i": "itermax", "f": "fwd", "r": "rev",
                        "g": "greedy", "t": "threshold"}

    def __init__(self,
                 embedding_provider: Optional[EmbeddingProvider] = None,
                 model: str = "bert",
                 token_type: str = "bpe",
                 distortion: float = 0.0,
                 matching_methods: str = "mai",
                 device: str = "cpu",
                 layer: int = 8):
        """
        Initialize SentenceAligner.

        Args:
            embedding_provider: Optional pre-configured EmbeddingProvider instance.
            model: Model name (used if embedding_provider is None)
            token_type: "bpe" for subword alignment, "word" for word-level
            distortion: Position distortion factor (0.0 = no distortion)
            matching_methods: String of method codes
            device: Device for computation
            layer: Transformer layer to extract embeddings from
        """
        self.token_type = token_type
        self.distortion = distortion
        # Unknown one-letter codes are silently ignored.
        self.matching_methods = [self.MATCHING_METHODS[m] for m in matching_methods
                                 if m in self.MATCHING_METHODS]

        if embedding_provider is not None:
            self.embed_provider = embedding_provider
        else:
            self.embed_provider = TransformerEmbeddingProvider(model=model, device=device, layer=layer)

    # -------------------------------------------------------------------------
    # Static alignment algorithms
    # -------------------------------------------------------------------------

    @staticmethod
    def get_similarity(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        """Pairwise cosine similarity rescaled from [-1, 1] into [0, 1]."""
        return (cosine_similarity(X, Y) + 1.0) / 2.0

    @staticmethod
    def apply_distortion(sim_matrix: np.ndarray, ratio: float = 0.5) -> np.ndarray:
        """Down-weight similarity between positionally distant tokens.

        A ratio of 0.0 or a degenerate (single-row/column) matrix is
        returned unchanged.
        """
        shape = sim_matrix.shape
        if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
            return sim_matrix
        # Relative positions in [0, 1] along each axis.
        pos_x = np.array([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])])
        pos_y = np.array([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])])
        # Quadratic penalty on the positional distance, scaled by ratio.
        distortion_mask = 1.0 - ((pos_x - np.transpose(pos_y)) ** 2) * ratio
        return np.multiply(sim_matrix, distortion_mask)

    @staticmethod
    def get_alignment_matrix(sim_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return (forward, reverse) 0/1 argmax matrices, both shaped (m, n)."""
        m, n = sim_matrix.shape
        forward = np.eye(n)[sim_matrix.argmax(axis=1)]   # src -> best tgt
        backward = np.eye(m)[sim_matrix.argmax(axis=0)]  # tgt -> best src
        return forward, backward.transpose()

    @staticmethod
    def get_max_weight_match(sim: np.ndarray) -> np.ndarray:
        """Maximum-weight full bipartite matching (requires networkx)."""
        if nx is None:
            raise ValueError("networkx must be installed to use mwmf algorithm.")

        def permute(edge):
            # from_biadjacency_matrix numbers target nodes after the source
            # nodes; map a matching edge back to (src_idx, tgt_idx).
            if edge[0] < sim.shape[0]:
                return edge[0], edge[1] - sim.shape[0]
            else:
                return edge[1], edge[0] - sim.shape[0]

        G = from_biadjacency_matrix(csr_matrix(sim))
        matching = nx.max_weight_matching(G, maxcardinality=True)
        matching = [permute(x) for x in matching]
        res_matrix = np.zeros_like(sim)
        for edge in matching:
            res_matrix[edge[0], edge[1]] = 1
        return res_matrix

    @staticmethod
    def iter_max(sim_matrix: np.ndarray, max_count: int = 2) -> np.ndarray:
        """Iterative mutual-argmax alignment (the itermax algorithm).

        Starts from the forward/backward intersection and repeatedly adds
        mutual argmax pairs among still-unaligned rows/columns, up to
        ``max_count`` rounds.
        """
        alpha_ratio = 0.9
        m, n = sim_matrix.shape
        forward = np.eye(n)[sim_matrix.argmax(axis=1)]
        backward = np.eye(m)[sim_matrix.argmax(axis=0)]
        inter = forward * backward.transpose()

        # Tiny matrices: the plain intersection is already the answer.
        if min(m, n) <= 2:
            return inter

        count = 1
        while count < max_count:
            # mask_x / mask_y flag rows/columns not yet covered by ``inter``.
            mask_x = 1.0 - np.tile(inter.sum(1)[:, np.newaxis], (1, n)).clip(0.0, 1.0)
            mask_y = 1.0 - np.tile(inter.sum(0)[np.newaxis, :], (m, 1)).clip(0.0, 1.0)
            mask = ((alpha_ratio * mask_x) + (alpha_ratio * mask_y)).clip(0.0, 1.0)
            mask_zeros = 1.0 - ((1.0 - mask_x) * (1.0 - mask_y))
            if mask_x.sum() < 1.0 or mask_y.sum() < 1.0:
                # Everything on one side is aligned; nothing left to add.
                mask *= 0.0
                mask_zeros *= 0.0

            new_sim = sim_matrix * mask
            fwd = np.eye(n)[new_sim.argmax(axis=1)] * mask_zeros
            bac = np.eye(m)[new_sim.argmax(axis=0)].transpose() * mask_zeros
            new_inter = fwd * bac

            if np.array_equal(inter + new_inter, inter):
                break
            inter = inter + new_inter
            count += 1
        return inter

    @staticmethod
    def greedy_match(sim_matrix: np.ndarray, one_to_one: bool = True) -> np.ndarray:
        """Greedy matching by descending similarity.

        Pairs are accepted best-first; with ``one_to_one`` each row and
        column is used at most once. Pairs with similarity <= 0 are never
        accepted.
        """
        m, n = sim_matrix.shape
        result = np.zeros((m, n))
        sim_flat = sim_matrix.flatten()
        indices = np.argsort(-sim_flat)  # descending similarity

        used_src = set() if one_to_one else None
        used_tgt = set() if one_to_one else None

        for idx in indices:
            i, j = idx // n, idx % n
            if sim_matrix[i, j] <= 0:
                break
            if one_to_one:
                if i in used_src or j in used_tgt:
                    continue
                used_src.add(i)
                used_tgt.add(j)
            result[i, j] = 1
        return result

    @staticmethod
    def threshold_match(sim_matrix: np.ndarray, threshold: float = 0.5) -> np.ndarray:
        """All pairs whose similarity reaches ``threshold`` (many-to-many)."""
        return (sim_matrix >= threshold).astype(float)

    # -------------------------------------------------------------------------
    # Main alignment methods
    # -------------------------------------------------------------------------

    def _prepare_tokens(self, sent: Union[str, List[str]]) -> Tuple[List[str], List[str], List[int]]:
        """Prepare tokens and get subword mapping.

        Accepts either a raw string (whitespace-split) or a pre-tokenized
        word list; returns (words, subwords, subword->word map).
        """
        if isinstance(sent, str):
            words = sent.split()
        else:
            words = sent
        subwords, word_map = self.embed_provider.get_subword_to_word_map(words)
        return words, subwords, word_map

    def _compute_alignments(self, sim: np.ndarray) -> Dict[str, np.ndarray]:
        """Compute all requested alignment matrices.

        fwd/rev/inter are always computed (they are cheap and several
        methods derive from them); the more expensive methods only when
        requested in ``self.matching_methods``.
        """
        all_mats = {}
        all_mats["fwd"], all_mats["rev"] = self.get_alignment_matrix(sim)
        all_mats["inter"] = all_mats["fwd"] * all_mats["rev"]

        if "mwmf" in self.matching_methods:
            all_mats["mwmf"] = self.get_max_weight_match(sim)
        if "itermax" in self.matching_methods:
            all_mats["itermax"] = self.iter_max(sim)
        if "greedy" in self.matching_methods:
            all_mats["greedy"] = self.greedy_match(sim)
        if "threshold" in self.matching_methods:
            all_mats["threshold"] = self.threshold_match(sim)

        return all_mats

    def _average_embeds_over_words(self,
                                   bpe_vectors: List[np.ndarray],
                                   words: List[List[str]],
                                   subwords: List[List[str]]) -> List[np.ndarray]:
        """Average subword embeddings to get word-level embeddings.

        Args:
            bpe_vectors: Subword-level embedding matrices for [src, tgt].
            words: Word lists for [src, tgt].
            subwords: Kept for interface compatibility; the subword->word
                map is recomputed from ``words`` instead.

        Words whose subwords fall outside the (possibly truncated) vector
        matrix get a zero vector.
        """
        new_vectors = []
        for lang_idx in range(2):
            word_list = words[lang_idx]
            _, word_map = self.embed_provider.get_subword_to_word_map(word_list)

            word_vectors = []
            for word_idx in range(len(word_list)):
                subword_indices = [i for i, w in enumerate(word_map) if w == word_idx]
                if subword_indices and max(subword_indices) < bpe_vectors[lang_idx].shape[0]:
                    word_vectors.append(bpe_vectors[lang_idx][subword_indices].mean(0))
                else:
                    word_vectors.append(np.zeros(bpe_vectors[lang_idx].shape[1]))

            new_vectors.append(np.array(word_vectors) if word_vectors else np.zeros((0, bpe_vectors[lang_idx].shape[1])))
        return new_vectors

    def get_text_embeddings(self, text: str) -> Tuple[List[TokenInfo], np.ndarray]:
        """
        Get embeddings for a text string.

        Args:
            text: Input text

        Returns:
            Tuple of (tokens, embeddings)
        """
        tokens = self.embed_provider.tokenize_text(text)
        words = [t.text for t in tokens]

        # Get subword mappings
        subwords, word_map = self.embed_provider.get_subword_to_word_map(words)

        # Prefer the batched API when the provider offers one.
        if hasattr(self.embed_provider, 'get_embeddings_batch'):
            vectors = self.embed_provider.get_embeddings_batch([words])[0]
        else:
            vectors = self.embed_provider.get_embeddings(words)

        # Truncate to subword lengths (padding positions are not wanted).
        vectors = vectors[:len(subwords)]

        # Average over words if needed
        if self.token_type == "word":
            vectors = self._average_embeds_over_words(
                [vectors, vectors],  # Hack: pass twice to satisfy list expectation
                [words, words],
                [subwords, subwords]
            )[0]

        return tokens, vectors

    def _get_embeddings_and_similarity(self, src_words: List[str], tgt_words: List[str],
                                       compute_sim: bool = True) -> Tuple[List[np.ndarray], Optional[np.ndarray]]:
        """Get embeddings for source and target words, optionally compute similarity.

        Returns ([src_vectors, tgt_vectors], sim) where ``sim`` is None
        when ``compute_sim`` is False.
        """
        # Get subword mappings
        src_subwords, _ = self.embed_provider.get_subword_to_word_map(src_words)
        tgt_subwords, _ = self.embed_provider.get_subword_to_word_map(tgt_words)

        # Get embeddings (batched when the provider supports it)
        if hasattr(self.embed_provider, 'get_embeddings_batch'):
            vectors = self.embed_provider.get_embeddings_batch([src_words, tgt_words])
        else:
            vectors = [
                self.embed_provider.get_embeddings(src_words),
                self.embed_provider.get_embeddings(tgt_words)
            ]

        # Truncate to subword lengths
        vectors = [vectors[0][:len(src_subwords)], vectors[1][:len(tgt_subwords)]]

        # Average over words if needed
        if self.token_type == "word":
            vectors = self._average_embeds_over_words(
                vectors, [src_words, tgt_words], [src_subwords, tgt_subwords]
            )

        # Compute similarity if requested
        sim = None
        if compute_sim:
            sim = self.get_similarity(vectors[0], vectors[1])
            sim = self.apply_distortion(sim, self.distortion)

        return vectors, sim

    def _get_subword_range(self, word_start: int, word_end: int,
                           w2b: List[int]) -> Tuple[int, int]:
        """Convert a [word_start, word_end) word range to a subword range.

        In "word" mode word indices are returned unchanged. In "bpe" mode,
        when no subword maps to a word at or past a boundary the default is
        ``len(w2b)`` (an empty tail slice); the previous default of 0 for
        the start index wrongly expanded an out-of-range span to the whole
        sentence.
        """
        if self.token_type == "bpe":
            sub_start = next((i for i, w in enumerate(w2b) if w >= word_start), len(w2b))
            sub_end = next((i for i, w in enumerate(w2b) if w >= word_end), len(w2b))
            return sub_start, sub_end
        return word_start, word_end

    def _get_alignments_from_similarity(self, sim: np.ndarray, vectors: List[np.ndarray],
                                        src_w2b: List[int], tgt_w2b: List[int],
                                        n_src_sub: int, n_tgt_sub: int,
                                        n_src_word: int, n_tgt_word: int,
                                        tgt_offset: int = 0) -> Dict[str, List]:
        """
        Convert similarity matrix to alignment tuples.

        Args:
            sim: Similarity matrix
            vectors: Source and target embedding vectors
            src_w2b: Source subword-to-word mapping
            tgt_w2b: Target subword-to-word mapping
            n_src_sub: Number of source subwords
            n_tgt_sub: Number of target subwords
            n_src_word: Number of source words
            n_tgt_word: Number of target words
            tgt_offset: Offset for target indices (for partial alignment)

        Returns:
            Dictionary mapping method names to sorted alignment lists
        """
        all_mats = self._compute_alignments(sim)
        aligns = {method: set() for method in self.matching_methods}

        # Row/column counts depend on whether we align subwords or words.
        n_src = n_src_sub if self.token_type == "bpe" else n_src_word
        n_tgt = n_tgt_sub if self.token_type == "bpe" else n_tgt_word

        for i in range(min(vectors[0].shape[0], n_src)):
            for j in range(min(vectors[1].shape[0], n_tgt)):
                actual_tgt_idx = j + tgt_offset

                for method in self.matching_methods:
                    if method in all_mats and all_mats[method][i, j] > 0:
                        if self.token_type == "bpe":
                            # Map subword indices back to word indices; fall
                            # back to the raw index when out of map range.
                            src_idx = src_w2b[i] if i < len(src_w2b) else i
                            tgt_idx = tgt_w2b[actual_tgt_idx] if actual_tgt_idx < len(tgt_w2b) else actual_tgt_idx
                        else:
                            src_idx = i
                            tgt_idx = actual_tgt_idx

                        aligns[method].add((src_idx, tgt_idx))

        # Convert sets (which dedupe subword pairs) to sorted lists.
        return {method: sorted(align_set) for method, align_set in aligns.items()}

    def get_word_aligns(self,
                        src_sent: Union[str, List[str]],
                        trg_sent: Union[str, List[str]]) -> Tuple[Dict[str, List], List[np.ndarray], np.ndarray]:
        """
        Get word alignments between source and target sentences.
        Legacy API - for backwards compatibility.
        """
        src_words, src_subwords, src_w2b = self._prepare_tokens(src_sent)
        tgt_words, tgt_subwords, tgt_w2b = self._prepare_tokens(trg_sent)

        vectors, sim = self._get_embeddings_and_similarity(src_words, tgt_words)

        aligns = self._get_alignments_from_similarity(sim, vectors, src_w2b, tgt_w2b,
                                                      len(src_subwords), len(tgt_subwords),
                                                      len(src_words), len(tgt_words))

        return aligns, vectors, sim

    def align_texts(self, src_text: str, tgt_text: str) -> AlignmentResult:
        """
        Align two texts and return complete alignment result.

        This is the main entry point that handles tokenization, embedding, and alignment.
        All tokenization is done by the embedding provider for consistency.

        Args:
            src_text: Source text string
            tgt_text: Target text string

        Returns:
            AlignmentResult with alignments, tokens, and similarity matrix
        """
        # Fix: progress messages go through logging (already imported at
        # module level but previously unused) instead of bare print().
        logger = logging.getLogger(__name__)
        logger.debug("1.1 Computing alignments...")
        result = self.align_texts_partial(src_text, tgt_text, src_char_start=0, src_char_end=None)
        logger.debug("1.2 Alignments computed.")
        return result

    def align_texts_partial(self, src_text: str, tgt_text: str,
                            src_char_start: int = 0,
                            src_char_end: Optional[int] = None) -> AlignmentResult:
        """
        Align two texts with partial source range defined by character positions.

        Args:
            src_text: Source text string
            tgt_text: Target text string
            src_char_start: Start index (char position) in source text
            src_char_end: End index (char position) in source text (None = end of text)

        Returns:
            AlignmentResult with alignments for the partial source range.
        """
        # compute embeddings once, then delegate to the embedding-based API
        src_tokens, src_vectors = self.get_text_embeddings(src_text)
        tgt_tokens, tgt_vectors = self.get_text_embeddings(tgt_text)

        return self.align_texts_partial_with_embeddings(
            src_tokens, tgt_tokens, src_vectors, tgt_vectors, src_char_start, src_char_end
        )

    def align_texts_partial_with_embeddings(self,
                                            src_tokens: List[TokenInfo],
                                            tgt_tokens: List[TokenInfo],
                                            src_vectors: np.ndarray,
                                            tgt_vectors: np.ndarray,
                                            src_char_start: int,
                                            src_char_end: Optional[int] = None) -> AlignmentResult:
        """
        Align partial source text to target text using pre-computed embeddings and character positions.
        """
        src_words = [t.text for t in src_tokens]
        tgt_words = [t.text for t in tgt_tokens]

        # Default end: end of the last token (0 for empty input).
        if src_char_end is None:
            if src_tokens:
                src_char_end = src_tokens[-1].end
            else:
                src_char_end = 0

        # Map char range to token range
        src_start_idx = None
        src_end_idx = None

        for i, token in enumerate(src_tokens):
            t_start = token.start
            t_end = token.end

            # Start index: first token that ends after char_start
            if src_start_idx is None and t_end > src_char_start:
                src_start_idx = i

            # End index: last token that starts before char_end
            if t_start < src_char_end:
                src_end_idx = i + 1

        if src_start_idx is None:
            src_start_idx = 0
        if src_end_idx is None:
            src_end_idx = len(src_tokens)

        # Subword maps for both sides.
        src_subwords, src_w2b = self.embed_provider.get_subword_to_word_map(src_words)
        tgt_subwords, tgt_w2b = self.embed_provider.get_subword_to_word_map(tgt_words)

        # Determine source subword range
        src_sub_start, src_sub_end = self._get_subword_range(src_start_idx, src_end_idx, src_w2b)

        # Slice source vectors and re-base the word indices to the slice.
        src_partial_vecs = src_vectors[src_sub_start:src_sub_end]
        src_w2b_partial = [idx - src_start_idx for idx in src_w2b[src_sub_start:src_sub_end]]

        # Compute similarity (with distortion) against the full target.
        sim_partial = self.get_similarity(src_partial_vecs, tgt_vectors)
        sim_partial = self.apply_distortion(sim_partial, self.distortion)

        # Compute alignments
        aligns = self._get_alignments_from_similarity(
            sim_partial, [src_partial_vecs, tgt_vectors], src_w2b_partial, tgt_w2b,
            len(src_w2b_partial), len(tgt_subwords),
            src_end_idx - src_start_idx, len(tgt_words),
            tgt_offset=0
        )

        partial_src_tokens = src_tokens[src_start_idx:src_end_idx]

        return AlignmentResult(
            alignments=aligns,
            src_tokens=partial_src_tokens,
            tgt_tokens=tgt_tokens,
            similarity_matrix=sim_partial,
            src_vectors=src_partial_vecs,
            tgt_vectors=tgt_vectors
        )

    def align_texts_partial_substring(self, src_text: str, tgt_text: str,
                                      src_substring: str) -> AlignmentResult:
        """
        Align two texts using a substring to define the source range.

        This method finds the substring in the source text and aligns only that portion.

        Args:
            src_text: Source text string
            tgt_text: Target text string
            src_substring: Substring to find in source text

        Returns:
            AlignmentResult with alignments for the substring range

        Raises:
            ValueError: If substring is not found in source text
        """
        # Find substring position (first occurrence).
        substring_start = src_text.find(src_substring)
        if substring_start == -1:
            raise ValueError(f"Substring '{src_substring}' not found in source text")

        substring_end = substring_start + len(src_substring)

        # Use align_texts_partial with computed indices
        return self.align_texts_partial(src_text, tgt_text, substring_start, substring_end)

    def print_alignment(self, alignment_result: AlignmentResult, method: str = "inter") -> str:
        """Print alignments in a human-readable format, showing the aligned
        (src_idx, tgt_idx) pairs with their token texts for the given method.

        Returns:
            The printed report as a string. (Fix: the signature promised a
            str but the success path previously returned None.)
        """
        if method not in alignment_result.alignments:
            return f"Method '{method}' not found in alignment results."

        src_tokens = alignment_result.src_tokens
        tgt_tokens = alignment_result.tgt_tokens
        aligns = alignment_result.alignments[method]

        lines = []
        for alignment in aligns:
            line = f"  {alignment}: {src_tokens[alignment[0]].text} <-> {tgt_tokens[alignment[1]].text}"
            print(line)
            lines.append(line)
        return "\n".join(lines)