OntoLearner 1.4.11__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
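The main functional change in the hunks below is to the cross-attention taxonomy learner: the pair-at-a-time `CrossAttentionHead` gains a broadcast path that scores a whole block of child terms against every parent term in one call, which is what makes the new full NxN scoring mode in `AlexbekCrossAttnLearner` practical. The sketch below illustrates that broadcast scoring idea in isolation; it is not the package's code (the real head also applies learned query/key projections, RMSNorm and dropout before scoring), and every name in it is illustrative.

```python
import math
import torch

def pairwise_scores(child_vecs: torch.Tensor, parent_vecs: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Score every (child, parent) pair at once with per-head dot products."""
    n_child, hidden = child_vecs.shape
    n_parent = parent_vecs.shape[0]
    d_head = hidden // num_heads
    q = child_vecs.view(n_child, num_heads, d_head)
    k = parent_vecs.view(n_parent, num_heads, d_head)
    # (n_child, n_parent, num_heads): scaled dot product per head for every pair
    per_head = torch.einsum("chd,phd->cph", q, k) / math.sqrt(d_head)
    return torch.sigmoid(per_head.mean(-1))  # (n_child, n_parent) probabilities

scores = pairwise_scores(torch.randn(4, 64), torch.randn(7, 64), num_heads=8)
print(scores.shape)  # torch.Size([4, 7])
```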
@@ -11,365 +11,524 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
  from typing import Any, Dict, List, Optional, Tuple
-
+ import json
  import math
  import os
  import random
+ from datetime import datetime
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from sentence_transformers import SentenceTransformer
+ from tqdm import tqdm
+ from torch.cuda.amp import GradScaler

  from ...base import AutoLearner


  class RMSNorm(nn.Module):
- """Root Mean Square normalization with learnable scale.
-
- Computes per-position normalization:
- y = weight * x / sqrt(mean(x^2) + eps)
-
- This variant normalizes over the last dimension and keeps scale as a
- learnable parameter, similar to RMSNorm used in modern transformer stacks.
- """
-
+ """Root Mean Square normalization with learnable scale."""
  def __init__(self, dim: int, eps: float = 1e-6):
- """Initialize the RMSNorm layer.
-
- Args:
- dim: Size of the last (feature) dimension to normalize over.
- eps: Small constant added inside the square root for numerical
- stability.
- """
  super().__init__()
  self.eps = eps
  self.weight = nn.Parameter(torch.ones(dim))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Apply RMS normalization.
-
- Args:
- x: Input tensor of shape (..., dim).
-
- Returns:
- Tensor of the same shape as `x`, RMS-normalized over the last axis.
- """
  rms_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  return self.weight * (x * rms_inv)

-
  class CrossAttentionHead(nn.Module):
- """Minimal multi-head *pair* scorer using cross-attention-style projections.
-
- Given child vector `c` and parent vector `p`:
- q = W_q * c, k = W_k * p
- score_head = (q_h · k_h) / sqrt(d_head)
-
- We average the per-head scores and apply a sigmoid to produce a probability.
- This is not a full attention block—just a learnable similarity function.
- """
-
- def __init__(
- self, hidden_size: int, num_heads: int = 8, rms_norm_eps: float = 1e-6
- ):
- """Initialize projections and per-stream normalizers.
-
- Args:
- hidden_size: Dimensionality of input embeddings (child/parent).
- num_heads: Number of subspaces to split the projection into.
- rms_norm_eps: Epsilon for RMSNorm stability.
-
- Raises:
- AssertionError: If `hidden_size` is not divisible by `num_heads`.
- """
+ """Efficient multi-head cross-attention scorer for parent-child pairs."""
+ def __init__(self, hidden_size: int, num_heads: int = 8, rms_norm_eps: float = 1e-6, dropout: float = 0.1):
  super().__init__()
- assert hidden_size % num_heads == 0, (
- "hidden_size must be divisible by num_heads"
- )
+ assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
  self.hidden_size = hidden_size
  self.num_heads = num_heads
  self.dim_per_head = hidden_size // num_heads

- # Linear projections for queries (child) and keys (parent)
  self.query_projection = nn.Linear(hidden_size, hidden_size, bias=False)
  self.key_projection = nn.Linear(hidden_size, hidden_size, bias=False)

- # Pre-projection normalization for stability
  self.query_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
  self.key_norm = RMSNorm(hidden_size, eps=rms_norm_eps)

- # Xavier init helps stabilize training
+ self.dropout = nn.Dropout(dropout)
+
  nn.init.xavier_uniform_(self.query_projection.weight)
  nn.init.xavier_uniform_(self.key_projection.weight)

- def forward(
- self, child_embeddings: torch.Tensor, parent_embeddings: torch.Tensor
- ) -> torch.Tensor:
- """Score (child, parent) pairs.
+ def forward(self, child_embeddings: torch.Tensor, parent_embeddings: torch.Tensor) -> torch.Tensor:
+ """
+ Score (child, parent) pairs efficiently.

  Args:
- child_embeddings: Tensor of shape (batch, hidden_size).
- parent_embeddings: Tensor of shape (batch, hidden_size).
+ child_embeddings: (batch_child, hidden_size) or (1, n_terms, hidden_size) for broadcasting
+ parent_embeddings: (batch_parent, hidden_size) or (1, n_terms, hidden_size) for broadcasting

  Returns:
- Tensor of probabilities with shape (batch,), each in [0, 1].
+ scores: (batch_child, batch_parent) if both are 2D, or appropriate broadcast shape
  """
- batch_size, _ = child_embeddings.shape
+ # Handle 2D input (standard batch processing)
+ if child_embeddings.dim() == 2 and parent_embeddings.dim() == 2:
+ batch_size = child_embeddings.shape[0]
+ queries = self.query_norm(self.query_projection(child_embeddings))
+ keys = self.key_norm(self.key_projection(parent_embeddings))

- # Project and normalize
- queries = self.query_norm(self.query_projection(child_embeddings))
- keys = self.key_norm(self.key_projection(parent_embeddings))
+ queries = self.dropout(queries)
+ keys = self.dropout(keys)

- # Reshape into heads: (batch, heads, dim_per_head)
- queries = queries.view(batch_size, self.num_heads, self.dim_per_head)
- keys = keys.view(batch_size, self.num_heads, self.dim_per_head)
+ queries = queries.view(batch_size, self.num_heads, self.dim_per_head)
+ keys = keys.view(batch_size, self.num_heads, self.dim_per_head)

- # Scaled dot-product similarity per head -> (batch, heads)
- per_head_scores = (queries * keys).sum(-1) / math.sqrt(self.dim_per_head)
+ per_head_scores = (queries * keys).sum(-1) / math.sqrt(self.dim_per_head)
+ mean_score = per_head_scores.mean(-1)
+ return torch.sigmoid(mean_score)

- # Aggregate across heads -> (batch,)
- mean_score = per_head_scores.mean(-1)
+ # Handle 3D input for efficient matrix computation
+ elif child_embeddings.dim() == 3 and parent_embeddings.dim() == 3:
+ n_child = child_embeddings.shape[1]
+ n_parent = parent_embeddings.shape[1]

- # Map to probability
- return torch.sigmoid(mean_score)
+ queries = self.query_norm(self.query_projection(child_embeddings))
+ keys = self.key_norm(self.key_projection(parent_embeddings))

+ queries = self.dropout(queries)
+ keys = self.dropout(keys)

- class AlexbekCrossAttnLearner(AutoLearner):
- """Cross-Attention Taxonomy Learner (inherits AutoLearner).
+ queries = queries.view(1, n_child, self.num_heads, self.dim_per_head)
+ keys = keys.view(1, n_parent, self.num_heads, self.dim_per_head)

- Workflow
- - Encode terms with a SentenceTransformer.
- - Train a compact cross-attention head on (parent, child) pairs
- (positives + sampled negatives) using BCE loss.
- - Inference returns probabilities per pair; edges with prob >= 0.5 are
- labeled as positive.
+ queries = queries.squeeze(0).transpose(1, 2)
+ keys = keys.squeeze(0).transpose(1, 2)

+ per_head_scores = torch.einsum('chd,phd->cph', queries, keys) / math.sqrt(self.dim_per_head)
+ mean_score = per_head_scores.mean(-1)
+ return torch.sigmoid(mean_score)
+
+
+ class AlexbekCrossAttnLearner(AutoLearner):
  """
+ Cross-Attention Taxonomy Learner - faithful reproduction of Alexbek's approach.

+ This implementation follows the original paper's methodology:
+ - Computes full NxN pairwise scores for all term pairs
+ - Uses threshold-based prediction (0.5 default, or F1-optimized on validation)
+ - No candidate pre-filtering (can be optionally enabled for large taxonomies)
+ """
  def __init__(
- self,
- *,
- embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
- device: str = "cpu",
- num_heads: int = 8,
- lr: float = 5e-5,
- weight_decay: float = 0.01,
- num_epochs: int = 1,
- batch_size: int = 256,
- neg_ratio: float = 1.0, # negatives per positive
- output_dir: str = "./results/",
- seed: int = 42,
- **kwargs: Any,
+ self,
+ *,
+ embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ num_heads: int = 8,
+ dropout: float = 0.1,
+ lr: float = 1e-4,
+ weight_decay: float = 0.01,
+ num_epochs: int = 10,
+ batch_size: int = 256,
+ inference_batch_size: int = 512,
+ neg_ratio: float = 1.0,
+ top_k_candidates: Optional[int] = None, # None = original behavior (all pairs)
+ output_dir: str = "./results/",
+ seed: int = 42,
+ cache_embeddings: bool = True,
+ use_lr_scheduler: bool = True,
+ warmup_epochs: int = 1,
+ gradient_clip: float = 1.0,
+ use_amp: bool = True,
+ hard_negative_ratio: float = 0.0, # 0.0 = original (all random negatives)
+ patience: int = 3,
+ validation_split: float = 0.1,
+ normalize_embeddings: bool = True,
+ prediction_threshold: float = 0.5, # Original uses 0.5 or F1-optimized
+ optimize_threshold_on_val: bool = True, # Set True to replicate "Validation-F1" approach
+ **kwargs: Any,
  ):
- """Configure the learner.
-
- Args:
- embedding_model: SentenceTransformer model id/path for term encoding.
- device: 'cuda' or 'cpu'. If 'cuda' is requested but unavailable, CPU
- is used.
- num_heads: Number of heads in the cross-attention scorer.
- lr: Learning rate for AdamW.
- weight_decay: Weight decay for AdamW.
- num_epochs: Number of epochs to train the head.
- batch_size: Minibatch size for training and scoring loops.
- neg_ratio: Number of sampled negatives per positive during training.
- output_dir: Directory to store artifacts (reserved for future use).
- seed: Random seed for reproducibility.
- **kwargs: Passed through to `AutoLearner` base init.
-
- Side Effects:
- Creates `output_dir` if missing and seeds Python/Torch RNGs.
- """
  super().__init__(**kwargs)

- # hyperparameters / settings
  self.embedding_model_id = embedding_model
  self.requested_device = device
  self.num_heads = num_heads
+ self.dropout = dropout
  self.learning_rate = lr
  self.weight_decay = weight_decay
  self.num_epochs = num_epochs
  self.batch_size = batch_size
+ self.inference_batch_size = inference_batch_size
  self.negative_ratio = neg_ratio
+ self.top_k_candidates = top_k_candidates
  self.output_dir = output_dir
  self.seed = seed
+ self.cache_embeddings = cache_embeddings
+ self.use_lr_scheduler = use_lr_scheduler
+ self.warmup_epochs = warmup_epochs
+ self.gradient_clip = gradient_clip
+ self.use_amp = use_amp and torch.cuda.is_available()
+ self.hard_negative_ratio = hard_negative_ratio
+ self.patience = patience
+ self.validation_split = validation_split
+ self.normalize_embeddings = normalize_embeddings
+ self.prediction_threshold = prediction_threshold
+ self.optimize_threshold_on_val = optimize_threshold_on_val

- # Prefer requested device but gracefully fall back to CPU
  if torch.cuda.is_available() or self.requested_device == "cpu":
  self.device = torch.device(self.requested_device)
  else:
  self.device = torch.device("cpu")

- # Will be set in load()
  self.embedder: Optional[SentenceTransformer] = None
  self.cross_attn_head: Optional[CrossAttentionHead] = None
  self.embedding_dim: Optional[int] = None
-
- # Cache of term -> embedding tensor (on device)
  self.term_to_vector: Dict[str, torch.Tensor] = {}
+ self.scaler: Optional[GradScaler] = GradScaler() if self.use_amp else None
+ self.best_threshold: float = self.prediction_threshold

  os.makedirs(self.output_dir, exist_ok=True)
  random.seed(self.seed)
  torch.manual_seed(self.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(self.seed)

  def load(self, **kwargs: Any):
- """Load the sentence embedding model and initialize the cross-attention head.
-
- Args:
- **kwargs: Optional override, supports `embedding_model`.
-
- Side Effects:
- - Initializes `self.embedder` on the configured device.
- - Probes and stores `self.embedding_dim`.
- - Constructs `self.cross_attn_head` with the probed dimensionality.
- """
+ """Load the sentence embedding model and initialize the cross-attention head."""
  model_id = kwargs.get("embedding_model", self.embedding_model_id)
- self.embedder = SentenceTransformer(
- model_id, trust_remote_code=True, device=str(self.device)
- )
+ self.embedder = SentenceTransformer(model_id, trust_remote_code=True, device=str(self.device))

- # Probe output dimensionality using a dummy encode
- probe_embedding = self.embedder.encode(
- ["_dim_probe_"], convert_to_tensor=True, normalize_embeddings=False
- )
+ probe_embedding = self.embedder.encode(["_dim_probe_"],
+ convert_to_tensor=True,
+ normalize_embeddings=False)
  self.embedding_dim = int(probe_embedding.shape[-1])

- # Initialize the cross-attention head
  self.cross_attn_head = CrossAttentionHead(
- hidden_size=self.embedding_dim, num_heads=self.num_heads
+ hidden_size=self.embedding_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout
  ).to(self.device)

- def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
- """Train or infer taxonomy edges according to the AutoLearner contract.
+ def save_model(self, path: str) -> None:
+ """Save the trained cross-attention head."""
+ if self.cross_attn_head is None:
+ raise RuntimeError("No model to save")
+
+ checkpoint = {
+ 'model_state_dict': self.cross_attn_head.state_dict(),
+ 'embedding_dim': self.embedding_dim,
+ 'num_heads': self.num_heads,
+ 'dropout': self.dropout,
+ 'embedding_model_id': self.embedding_model_id,
+ 'best_threshold': self.best_threshold,
+ }
+
+ torch.save(checkpoint, path)
+ print(f"Model saved to {path}")
+
+ def load_model(self, path: str) -> None:
+ """Load a trained cross-attention head."""
+ checkpoint = torch.load(path, map_location=self.device)
+
+ self.embedding_dim = checkpoint['embedding_dim']
+ self.num_heads = checkpoint['num_heads']
+ self.dropout = checkpoint.get('dropout', 0.1)
+ self.embedding_model_id = checkpoint.get('embedding_model_id', self.embedding_model_id)
+ self.best_threshold = checkpoint.get('best_threshold', 0.5)
+
+ # Load embedder if not already loaded
+ if self.embedder is None:
+ self.embedder = SentenceTransformer(
+ self.embedding_model_id,
+ trust_remote_code=True,
+ device=str(self.device)
+ )

- Training (`test=False`)
- - Extract positives (parent, child) and the unique term set from `data`.
- - Build/extend the term embedding cache.
- - Sample negatives at ratio `self.negative_ratio`.
- - Train the cross-attention head with BCE loss.
+ self.cross_attn_head = CrossAttentionHead(
+ hidden_size=self.embedding_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout
+ ).to(self.device)

- Inference (`test=True`)
- - Ensure embeddings exist for all terms.
- - Score candidate pairs and return per-pair probabilities and labels.
+ self.cross_attn_head.load_state_dict(checkpoint['model_state_dict'])
+ print(f"Model loaded from {path} (threshold: {self.best_threshold:.3f})")
+
+ def save_config(self, path: str) -> None:
+ """Save hyperparameters to JSON."""
+ config = {
+ 'embedding_model': self.embedding_model_id,
+ 'num_heads': self.num_heads,
+ 'dropout': self.dropout,
+ 'lr': self.learning_rate,
+ 'weight_decay': self.weight_decay,
+ 'num_epochs': self.num_epochs,
+ 'batch_size': self.batch_size,
+ 'inference_batch_size': self.inference_batch_size,
+ 'negative_ratio': self.negative_ratio,
+ 'top_k_candidates': self.top_k_candidates,
+ 'use_lr_scheduler': self.use_lr_scheduler,
+ 'warmup_epochs': self.warmup_epochs,
+ 'gradient_clip': self.gradient_clip,
+ 'use_amp': self.use_amp,
+ 'hard_negative_ratio': self.hard_negative_ratio,
+ 'patience': self.patience,
+ 'validation_split': self.validation_split,
+ 'normalize_embeddings': self.normalize_embeddings,
+ 'prediction_threshold': self.prediction_threshold,
+ 'optimize_threshold_on_val': self.optimize_threshold_on_val,
+ 'best_threshold': self.best_threshold,
+ 'seed': self.seed,
+ }
+
+ with open(path, 'w') as f:
+ json.dump(config, f, indent=2)
+ print(f"Configuration saved to {path}")
+
+ @classmethod
+ def from_config(cls, config_path: str, **override_kwargs):
+ """Load from configuration file."""
+ with open(config_path, 'r') as f:
+ config = json.load(f)
+
+ config.update(override_kwargs)
+ return cls(**config)

- Args:
- data: Ontology-like object exposing `type_taxonomies.taxonomies`,
- where each item has `.parent` and `.child` string-like fields.
- test: If True, perform inference instead of training.
+ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
+ """
+ Train or infer taxonomy edges.

- Returns:
- - `None` on training.
- - On inference: List of dicts
- `{"parent": str, "child": str, "score": float, "label": int}`.
+ Original behavior: Scores ALL possible pairs and applies threshold.
+ Optional optimization: Can pre-filter to top-k candidates if top_k_candidates is set.
  """
  if self.embedder is None or self.cross_attn_head is None:
  self.load()

  if not test:
- positive_pairs, unique_terms = self._extract_parent_child_pairs_and_terms(
- data
- )
+ # Training mode
+ positive_pairs, unique_terms = self._extract_parent_child_pairs_and_terms(data, test=test)
  self._ensure_term_embeddings(unique_terms)
- negative_pairs = self._sample_negative_pairs(
- positive_pairs, unique_terms, ratio=self.negative_ratio, seed=self.seed
- )
+ negative_pairs = self._sample_negative_pairs(positive_pairs,
+ unique_terms,
+ ratio=self.negative_ratio,
+ seed=self.seed)
  self._train_cross_attn_head(positive_pairs, negative_pairs)
+
+ # Save model and config after training
+ model_path = os.path.join(self.output_dir, "best_model.pt")
+ config_path = os.path.join(self.output_dir, "config.json")
+ self.save_model(model_path)
+ self.save_config(config_path)
+
  return None
  else:
- candidate_pairs, unique_terms = self._extract_parent_child_pairs_and_terms(
- data
- )
+ # Inference mode
+ candidate_pairs, unique_terms = self._extract_parent_child_pairs_and_terms(data, test=test)
  self._ensure_term_embeddings(unique_terms, append_only=True)
- probabilities = self._score_parent_child_pairs(candidate_pairs)
-
- predictions = [
- {
- "parent": parent,
- "child": child,
- "score": float(prob),
- "label": int(prob >= 0.5),
- }
- for (parent, child), prob in zip(candidate_pairs, probabilities)
- ]
- return predictions

- def _ensure_term_embeddings(
- self, terms: List[str], append_only: bool = False
- ) -> None:
- """Encode terms with the sentence embedder and store in cache.
+ # Original approach: score all pairs
+ if self.top_k_candidates is None:
+ print(f"ORIGINAL MODE: Computing full {len(unique_terms)}x{len(unique_terms)} probability matrix...")
+ probabilities = self._score_all_pairs_efficient(unique_terms)
+
+ # Apply threshold to get predictions
+ print(f"Applying threshold {self.best_threshold:.3f} to extract predictions...")
+ binary_matrix = (probabilities >= self.best_threshold).cpu().numpy()
+
+ # Find indices where prediction is 1
+ child_indices, parent_indices = binary_matrix.nonzero()
+
+ # Get corresponding probabilities
+ probs = probabilities[child_indices, parent_indices].cpu().numpy()
+
+ # Build predictions
+ predictions = [
+ {
+ "parent": unique_terms[parent_idx],
+ "child": unique_terms[child_idx],
+ "score": float(prob),
+ "label": 1,
+ }
+ for child_idx, parent_idx, prob in zip(child_indices, parent_indices, probs)
+ if child_idx != parent_idx # Exclude self-loops
+ ]
+
+ print(
+ f"Found {len(predictions)} positive predictions from {len(unique_terms) ** 2 - len(unique_terms)} possible pairs")
+
+ else:
+ # Optional optimization: pre-filter candidates
+ print(f"OPTIMIZATION MODE: Filtering to top-{self.top_k_candidates} candidates per term...")
+ print("WARNING: This is NOT the original Alexbek approach but an efficiency optimization.")
+ candidate_pairs = self._filter_top_k_candidates(unique_terms, self.top_k_candidates)
+ print(f"Reduced to {len(candidate_pairs)} candidate pairs")
+
+ # Score filtered candidates
+ print("Scoring filtered candidate pairs...")
+ probabilities = self._score_specific_pairs(candidate_pairs)
+
+ # Apply threshold
+ predictions = [
+ {
+ "parent": parent,
+ "child": child,
+ "score": float(prob),
+ "label": 1,
+ }
+ for (parent, child), prob in zip(candidate_pairs, probabilities)
+ if prob >= self.best_threshold and parent != child
+ ]
+
+ print(f"Found {len(predictions)} positive predictions from {len(candidate_pairs)} candidate pairs")

- Args:
- terms: List of unique term strings to embed.
- append_only: If True, only embed terms missing from the cache;
- otherwise (re)encode all provided terms.
+ return predictions

- Raises:
- RuntimeError: If called before `load()`.
- """
+ def _ensure_term_embeddings(self, terms: List[str], append_only: bool = False) -> None:
+ """Encode terms efficiently with batching."""
  if self.embedder is None:
  raise RuntimeError("Call load() before building term embeddings")

- terms_to_encode = (
- [t for t in terms if t not in self.term_to_vector] if append_only else terms
- )
+ terms_to_encode = ([t for t in terms if t not in self.term_to_vector] if append_only else terms)
  if not terms_to_encode:
  return

+ # Batch encode terms with normalization
  embeddings = self.embedder.encode(
  terms_to_encode,
  convert_to_tensor=True,
- normalize_embeddings=False,
- batch_size=256,
- show_progress_bar=False,
+ normalize_embeddings=self.normalize_embeddings,
+ batch_size=self.inference_batch_size,
+ show_progress_bar=True,
  )
+
  for term, embedding in zip(terms_to_encode, embeddings):
- self.term_to_vector[term] = embedding.detach().to(self.device)
+ self.term_to_vector[term] = embedding.to(self.device)

- def _pairs_as_tensors(
- self, pairs: List[Tuple[str, str]]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Convert string pairs into aligned embedding tensors on the correct device.
+ def _score_specific_pairs(self, pairs: List[Tuple[str, str]]) -> List[float]:
+ """
+ Score only specific (parent, child) pairs efficiently in batches.

  Args:
- pairs: List of (parent, child) term strings.
+ pairs: List of (parent, child) tuples to score

  Returns:
- Tuple `(child_tensor, parent_tensor)` where each tensor has shape
- `(batch, embedding_dim)` and is located on `self.device`.
+ List of probability scores corresponding to input pairs
+ """
+ if self.cross_attn_head is None:
+ raise RuntimeError("Head not initialized. Call load().")
+
+ self.cross_attn_head.eval()
+ scores: List[float] = []
+
+ with torch.no_grad():
+ for start in tqdm(range(0, len(pairs), self.inference_batch_size), desc="Scoring pairs"):
+ chunk = pairs[start: start + self.inference_batch_size]
+ child_tensor, parent_tensor = self._pairs_as_tensors(chunk)
+ prob = self.cross_attn_head(child_tensor, parent_tensor)
+ scores.extend(prob.detach().cpu().tolist())

- Notes:
- This function assumes that all terms in `pairs` are present in
- `self.term_to_vector`. Use `_ensure_term_embeddings` beforehand.
+ return scores
+
+ def _score_all_pairs_efficient(self, terms: List[str]) -> torch.Tensor:
  """
- # child embeddings tensor of shape (batch, dim)
- child_tensor = torch.stack(
- [self.term_to_vector[child] for (_, child) in pairs], dim=0
- ).to(self.device)
- # parent embeddings tensor of shape (batch, dim)
- parent_tensor = torch.stack(
- [self.term_to_vector[parent] for (parent, _) in pairs], dim=0
- ).to(self.device)
- return child_tensor, parent_tensor
+ Efficiently score all pairs using matrix operations (ORIGINAL ALEXBEK APPROACH).

- def _train_cross_attn_head(
- self,
- positive_pairs: List[Tuple[str, str]],
- negative_pairs: List[Tuple[str, str]],
- ) -> None:
- """Train the cross-attention head with BCE loss on labeled pairs.
+ Returns:
+ scores: (n_terms, n_terms) matrix where scores[i,j] = P(terms[j] is parent of terms[i])
+ """
+ if self.cross_attn_head is None:
+ raise RuntimeError("Head not initialized. Call load().")

- The dataset is a concatenation of positives (label 1) and sampled
- negatives (label 0). The head is optimized with AdamW.
+ self.cross_attn_head.eval()
+ n_terms = len(terms)

- Args:
- positive_pairs: List of ground-truth (parent, child) edges.
- negative_pairs: List of sampled non-edges.
+ # Stack all embeddings
+ all_embeddings = torch.stack([self.term_to_vector[t] for t in terms], dim=0)
+
+ # Compute scores in chunks to manage memory
+ scores_matrix = torch.zeros((n_terms, n_terms), device=self.device)
+
+ with torch.no_grad():
+ chunk_size = self.inference_batch_size
+
+ progress_bar = tqdm(
+ range(0, n_terms, chunk_size),
+ desc="Scoring all pairs",
+ total=(n_terms + chunk_size - 1) // chunk_size
+ )
+
+ for i in progress_bar:
+ end_i = min(i + chunk_size, n_terms)
+ child_chunk = all_embeddings[i:end_i]
+
+ # Score against all parents at once
+ child_broadcast = child_chunk.unsqueeze(0)
+ parent_broadcast = all_embeddings.unsqueeze(0)
+
+ chunk_scores = self.cross_attn_head(child_broadcast, parent_broadcast)
+ scores_matrix[i:end_i, :] = chunk_scores

- Raises:
- RuntimeError: If the head has not been initialized (call `load()`).
+ progress_bar.set_postfix({
+ 'completed': f'{end_i}/{n_terms}',
+ 'pairs_scored': f'{end_i * n_terms:,}'
+ })
+
+ return scores_matrix
+
+ def _pairs_as_tensors(self, pairs: List[Tuple[str, str]]) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Convert string pairs into aligned embedding tensors."""
+ child_tensor = torch.stack([self.term_to_vector[child] for (_, child) in pairs], dim=0)
+ parent_tensor = torch.stack([self.term_to_vector[parent] for (parent, _) in pairs], dim=0)
+ return child_tensor, parent_tensor
+
+ def _optimize_threshold_on_validation(self, val_pairs: List[Tuple[int, Tuple[str, str]]]) -> float:
+ """
+ Find optimal threshold that maximizes F1 on validation set.
+ This replicates the "Validation-F1" approach from the paper.
  """
+ if not val_pairs:
+ return self.prediction_threshold
+
+ print("Optimizing prediction threshold on validation set...")
+ self.cross_attn_head.eval()
+
+ # Get validation labels and scores
+ val_labels = []
+ val_scores = []
+
+ with torch.no_grad():
+ for label, (parent, child) in val_pairs:
+ val_labels.append(label)
+ child_tensor = self.term_to_vector[child].unsqueeze(0)
+ parent_tensor = self.term_to_vector[parent].unsqueeze(0)
+ score = self.cross_attn_head(child_tensor, parent_tensor).item()
+ val_scores.append(score)
+
+ val_labels = torch.tensor(val_labels)
+ val_scores = torch.tensor(val_scores)
+
+ # Try different thresholds
+ best_f1 = 0.0
+ best_threshold = 0.5
+
+ for threshold in torch.linspace(0.1, 0.9, 50):
+ preds = (val_scores >= threshold).long()
+
+ tp = ((preds == 1) & (val_labels == 1)).sum().item()
+ fp = ((preds == 1) & (val_labels == 0)).sum().item()
+ fn = ((preds == 0) & (val_labels == 1)).sum().item()
+
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+
+ if f1 > best_f1:
+ best_f1 = f1
+ best_threshold = threshold.item()
+
+ print(f"Optimal threshold: {best_threshold:.3f} (F1: {best_f1:.4f})")
+ return best_threshold
+
+ def _train_cross_attn_head(self,
+ positive_pairs: List[Tuple[str, str]],
+ negative_pairs: List[Tuple[str, str]]) -> None:
+ """Train the cross-attention head with BCE loss, validation, and early stopping."""
  if self.cross_attn_head is None:
  raise RuntimeError("Head not initialized. Call load().")

@@ -380,121 +539,284 @@ class AlexbekCrossAttnLearner(AutoLearner):
  weight_decay=self.weight_decay,
  )

- # Build a simple supervised dataset: 1 for positive, 0 for negative
- labeled_pairs: List[Tuple[int, Tuple[str, str]]] = [
- (1, pc) for pc in positive_pairs
- ] + [(0, nc) for nc in negative_pairs]
+ # Prepare labeled pairs and split into train/val
+ labeled_pairs: List[Tuple[int, Tuple[str, str]]] = [(1, pc) for pc in positive_pairs] + \
+ [(0, nc) for nc in negative_pairs]
  random.shuffle(labeled_pairs)

- def iterate_minibatches(
- items: List[Tuple[int, Tuple[str, str]]], batch_size: int
- ):
- """Yield contiguous minibatches of size `batch_size` from `items`."""
- for start in range(0, len(items), batch_size):
- yield items[start : start + batch_size]
+ split_idx = int((1 - self.validation_split) * len(labeled_pairs))
+ train_pairs = labeled_pairs[:split_idx]
+ val_pairs = labeled_pairs[split_idx:]
+
+ print(f"Training samples: {len(train_pairs)}, Validation samples: {len(val_pairs)}")
+
+ # Setup learning rate scheduler
+ scheduler = None
+ if self.use_lr_scheduler:
+ total_steps = (len(train_pairs) // self.batch_size + 1) * self.num_epochs
+ warmup_steps = (len(train_pairs) // self.batch_size + 1) * self.warmup_epochs
+
+ def lr_lambda(step):
+ if step < warmup_steps:
+ return (step + 1) / (warmup_steps + 1)
+ else:
+ progress = (step - warmup_steps) / (total_steps - warmup_steps)
+ return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))
+
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

- for epoch in range(self.num_epochs):
+ def iterate_minibatches(items: List[Tuple[int, Tuple[str, str]]], batch_size: int):
+ for start in range(0, len(items), batch_size):
+ yield items[start: start + batch_size]
+
+ # Training loop with early stopping
+ best_val_loss = float('inf')
+ best_model_state = None
+ patience_counter = 0
+ metrics_history = []
+ global_step = 0
+
+ for epoch in tqdm(range(self.num_epochs), desc="Training"):
+ # Training phase
+ self.cross_attn_head.train()
  epoch_loss_sum = 0.0
- for minibatch in iterate_minibatches(labeled_pairs, self.batch_size):
- labels = torch.tensor(
- [y for y, _ in minibatch], dtype=torch.float32, device=self.device
- )
+
+ for minibatch in iterate_minibatches(train_pairs, self.batch_size):
+ labels = torch.tensor([y for y, _ in minibatch], dtype=torch.float32, device=self.device)
  string_pairs = [pc for _, pc in minibatch]
  child_tensor, parent_tensor = self._pairs_as_tensors(string_pairs)

- probs = self.cross_attn_head(child_tensor, parent_tensor)
- loss = F.binary_cross_entropy(probs, labels)
-
  optimizer.zero_grad()
- loss.backward()
- optimizer.step()

- epoch_loss_sum += float(loss.item()) * len(minibatch)
+ # Mixed precision training
+ if self.use_amp:
+ probs = self.cross_attn_head(child_tensor, parent_tensor)
+ loss = F.binary_cross_entropy(probs, labels)

- def _score_parent_child_pairs(self, pairs: List[Tuple[str, str]]) -> List[float]:
- """Compute probability scores for (parent, child) pairs.
+ self.scaler.scale(loss).backward()

- Args:
- pairs: List of candidate (parent, child) edges to score.
+ if self.gradient_clip > 0:
+ self.scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(
+ self.cross_attn_head.parameters(),
+ self.gradient_clip
+ )

- Returns:
- List of floats in [0, 1] corresponding to the input order.
+ self.scaler.step(optimizer)
+ self.scaler.update()
+ else:
+ probs = self.cross_attn_head(child_tensor, parent_tensor)
+ loss = F.binary_cross_entropy(probs, labels)

- Raises:
- RuntimeError: If the head has not been initialized (call `load()`).
- """
- if self.cross_attn_head is None:
- raise RuntimeError("Head not initialized. Call load().")
+ loss.backward()

- self.cross_attn_head.eval()
- scores: List[float] = []
- with torch.no_grad():
- for start in range(0, len(pairs), self.batch_size):
- chunk = pairs[start : start + self.batch_size]
- child_tensor, parent_tensor = self._pairs_as_tensors(chunk)
- prob = self.cross_attn_head(child_tensor, parent_tensor)
- scores.extend(prob.detach().cpu().tolist())
- return scores
+ if self.gradient_clip > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.cross_attn_head.parameters(),
+ self.gradient_clip
+ )
+
+ optimizer.step()

- def _extract_parent_child_pairs_and_terms(
- self, data: Any
- ) -> Tuple[List[Tuple[str, str]], List[str]]:
- """Extract (parent, child) edges and the set of unique terms from an ontology-like object.
+ if scheduler is not None:
+ scheduler.step()
+ global_step += 1
+
+ epoch_loss_sum += float(loss.item()) * len(minibatch)

- The function expects `data.type_taxonomies.taxonomies` to be an iterable
- of objects with `.parent` and `.child` string-like attributes.
+ avg_train_loss = epoch_loss_sum / len(train_pairs)
+
+ # Validation phase
+ self.cross_attn_head.eval()
+ val_loss_sum = 0.0
+
+ with torch.no_grad():
+ for minibatch in iterate_minibatches(val_pairs, self.batch_size):
+ labels = torch.tensor([y for y, _ in minibatch], dtype=torch.float32, device=self.device)
+ string_pairs = [pc for _, pc in minibatch]
+ child_tensor, parent_tensor = self._pairs_as_tensors(string_pairs)
+
+ probs = self.cross_attn_head(child_tensor, parent_tensor)
+ loss = F.binary_cross_entropy(probs, labels)
+ val_loss_sum += float(loss.item()) * len(minibatch)
+
+ avg_val_loss = val_loss_sum / len(val_pairs)
+
+ # Track metrics
+ current_lr = optimizer.param_groups[0]['lr']
+ metrics = {
+ 'epoch': epoch + 1,
+ 'train_loss': avg_train_loss,
+ 'val_loss': avg_val_loss,
+ 'learning_rate': current_lr,
+ 'timestamp': datetime.now().isoformat()
+ }
+ metrics_history.append(metrics)
+
+ # Save best model
+ if avg_val_loss < best_val_loss:
+ best_val_loss = avg_val_loss
+ best_model_state = self.cross_attn_head.state_dict()
+ patience_counter = 0
+ print(f"Epoch {epoch + 1}/{self.num_epochs} | "
+ f"Train Loss: {avg_train_loss:.4f} | "
+ f"Val Loss: {avg_val_loss:.4f} ⭐ (Best) | "
+ f"LR: {current_lr:.6f}")
+ else:
+ patience_counter += 1
+ print(f"Epoch {epoch + 1}/{self.num_epochs} | "
+ f"Train Loss: {avg_train_loss:.4f} | "
+ f"Val Loss: {avg_val_loss:.4f} | "
+ f"LR: {current_lr:.6f} | "
+ f"Patience: {patience_counter}/{self.patience}")
+
+ # Early stopping
+ if patience_counter >= self.patience:
+ print(f"Early stopping triggered at epoch {epoch + 1}")
+ break
+
+ # Restore best model
+ if best_model_state is not None:
+ self.cross_attn_head.load_state_dict(best_model_state)
+ print(f"Restored best model with validation loss: {best_val_loss:.4f}")
+
+ # Optimize threshold on validation set if requested
+ if self.optimize_threshold_on_val and val_pairs:
+ self.best_threshold = self._optimize_threshold_on_validation(val_pairs)
+ else:
+ self.best_threshold = self.prediction_threshold
+
+ # Save training metrics
+ metrics_path = os.path.join(self.output_dir, 'training_metrics.json')
+ with open(metrics_path, 'w') as f:
+ json.dump(metrics_history, f, indent=2)
+
+ def _filter_top_k_candidates(self, terms: List[str], top_k: int) -> List[Tuple[str, str]]:
+ """
+ OPTIONAL OPTIMIZATION (NOT in original Alexbek paper):
+ Filter candidate pairs to only include top-k most similar terms based on cosine similarity.
+ Memory-efficient chunked implementation.

  Args:
- data: Ontology-like container.
+ terms: List of unique terms
+ top_k: Number of most similar candidates to keep per term

  Returns:
- A tuple `(pairs, terms)` where:
- - `pairs` is a list of (parent, child) strings,
- - `terms` is a sorted list of unique term strings (parents ∪ children).
+ List of (parent, child) candidate pairs
  """
+ n_terms = len(terms)
+
+ # Stack all embeddings and normalize for cosine similarity
+ all_embeddings = torch.stack([self.term_to_vector[t] for t in terms], dim=0)
+ normalized_embeddings = F.normalize(all_embeddings, p=2, dim=1)
+
+ candidate_pairs = []
+
+ # Process in chunks to avoid OOM for large taxonomies
+ chunk_size = min(1000, n_terms)
+
+ print("Finding top-k similar terms for each term...")
+ for child_start in tqdm(range(0, n_terms, chunk_size), desc="Filtering candidates"):
+ child_end = min(child_start + chunk_size, n_terms)
+ child_chunk = normalized_embeddings[child_start:child_end]
+
+ # Compute similarities for this chunk
+ similarities = torch.mm(child_chunk, normalized_embeddings.t())
+
+ # Get top-k+1 for each child in chunk (to exclude self if needed)
+ top_k_values, top_k_indices = torch.topk(similarities, min(top_k + 1, n_terms), dim=1)
+
+ # Add pairs (excluding self-loops)
+ for local_idx, child_idx in enumerate(range(child_start, child_end)):
+ for parent_idx in top_k_indices[local_idx].cpu().tolist():
+ if parent_idx != child_idx:
+ candidate_pairs.append((terms[parent_idx], terms[child_idx]))
+
+ return candidate_pairs
+
+ def _extract_parent_child_pairs_and_terms(self, data: Any, test: bool) -> Tuple[List[Tuple[str, str]], List[str]]:
+ """Extract (parent, child) edges and unique terms from ontology data."""
  parent_child_pairs: List[Tuple[str, str]] = []
  unique_terms = set()
+
  for edge in getattr(data, "type_taxonomies").taxonomies:
  parent, child = str(edge.parent), str(edge.child)
- parent_child_pairs.append((parent, child))
+ if not test:
+ parent_child_pairs.append((parent, child))
  unique_terms.add(parent)
  unique_terms.add(child)
+
+ if test:
+ # In test mode, return empty pairs - will score all pairs in _taxonomy_discovery
+ pass
+
  return parent_child_pairs, sorted(unique_terms)

  def _sample_negative_pairs(
- self,
- positive_pairs: List[Tuple[str, str]],
- terms: List[str],
- ratio: float = 1.0,
- seed: int = 42,
+ self,
+ positive_pairs: List[Tuple[str, str]],
+ terms: List[str],
+ ratio: float = 1.0,
+ seed: int = 42,
  ) -> List[Tuple[str, str]]:
- """Sample random negative (parent, child) pairs not present in positives.
-
- Sampling is uniform over the Cartesian product of `terms` excluding
- (x, x) self-pairs and any pair found in `positive_pairs`.
-
- Args:
- positive_pairs: Known positive edges to exclude.
- terms: Candidate vocabulary (parents ∪ children).
- ratio: Number of negatives per positive to draw.
- seed: RNG seed used for reproducible sampling.
+ """
+ Sample negative pairs.

- Returns:
- A list of sampled negative pairs of approximate length
- `int(len(positive_pairs) * ratio)`.
+ Original approach: All random negatives (hard_negative_ratio=0.0)
+ Optional: Mix of hard negatives and random negatives (hard_negative_ratio>0.0)
  """
  random.seed(seed)
  term_list = list(terms)
  positive_set = set(positive_pairs)
- negatives: List[Tuple[str, str]] = []
- target_negative_count = int(len(positive_pairs) * ratio)
- while len(negatives) < target_negative_count:
- parent = random.choice(term_list)
- child = random.choice(term_list)
- if parent == child:
- continue
- candidate = (parent, child)
- if candidate in positive_set:
- continue
- negatives.append(candidate)
+
+ target_count = int(len(positive_pairs) * ratio)
+ hard_count = int(target_count * self.hard_negative_ratio) if self.hard_negative_ratio > 0 else 0
+ random_count = target_count - hard_count
+
+ negatives = []
+
+ # Hard negatives: pairs with high embedding similarity but not in taxonomy
+ if hard_count > 0:
+ print(f"Sampling {hard_count} hard negatives based on embedding similarity...")
+ all_embeddings = torch.stack([self.term_to_vector[t] for t in term_list])
+ normalized_embeddings = F.normalize(all_embeddings, p=2, dim=1)
+ similarity_matrix = torch.mm(normalized_embeddings, normalized_embeddings.t())
+
+ # For each term, get candidates sorted by similarity
+ for i in range(len(term_list)):
+ if len(negatives) >= hard_count:
+ break
+
+ similarities = similarity_matrix[i]
+ sorted_indices = torch.argsort(similarities, descending=True)
+
+ for j in sorted_indices:
+ j_idx = j.item()
+ if i == j_idx:
+ continue
+ candidate = (term_list[j_idx], term_list[i])
+ if candidate not in positive_set and candidate not in negatives:
+ negatives.append(candidate)
+ if len(negatives) >= hard_count:
+ break
+ # Random negatives (ORIGINAL ALEXBEK APPROACH)
+ if random_count > 0:
+ print(f"Sampling {random_count} random negatives...")
+ max_attempts = random_count * 10
+ attempts = 0
+
+ while len(negatives) < target_count and attempts < max_attempts:
+ parent = random.choice(term_list)
+ child = random.choice(term_list)
+ attempts += 1
+
+ if parent == child:
+ continue
+ candidate = (parent, child)
+ if candidate not in positive_set and candidate not in negatives:
+ negatives.append(candidate)
+ if hard_count > 0:
+ print(f"Sampled {len(negatives)} negative pairs ({hard_count} hard, {len(negatives) - hard_count} random)")
+ else:
+ print(f"Sampled {len(negatives)} random negative pairs (original approach)")
  return negatives
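Taken together, 1.5.0 turns the learner into a train/validate/score pipeline: it splits the labeled pairs, trains with optional AMP, warmup plus cosine decay, gradient clipping and early stopping, optionally tunes the decision threshold for validation F1, writes `best_model.pt` and `config.json` to `output_dir`, and at inference scores the full term-by-term matrix (or a top-k filtered subset). A rough usage sketch follows; the import path of `AlexbekCrossAttnLearner` and the public entry point that wraps `_taxonomy_discovery` are not shown in this diff, and the toy ontology object below is hypothetical, so treat it as an illustration of the call flow rather than documented API.

```python
from types import SimpleNamespace
# from ontolearner import AlexbekCrossAttnLearner  # exact import path not shown in this diff

# Hypothetical stand-in for an ontology: the learner only reads
# data.type_taxonomies.taxonomies, where each item exposes .parent and .child.
edges = [("animal", "dog"), ("animal", "cat"), ("plant", "tree"), ("plant", "flower")]
data = SimpleNamespace(
    type_taxonomies=SimpleNamespace(
        taxonomies=[SimpleNamespace(parent=p, child=c) for p, c in edges]
    )
)

learner = AlexbekCrossAttnLearner(num_epochs=2, output_dir="./results/")
learner.load()                                 # SentenceTransformer + CrossAttentionHead
learner._taxonomy_discovery(data, test=False)  # train; writes best_model.pt / config.json
preds = learner._taxonomy_discovery(data, test=True)  # full NxN scoring + threshold
print(preds[:3])  # [{"parent": ..., "child": ..., "score": ..., "label": 1}, ...]
```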