OntoLearner 1.4.6__py3-none-any.whl → 1.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,500 @@
1
+ # Copyright (c) 2025 SciKnowOrg
2
+ #
3
+ # Licensed under the MIT License (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://opensource.org/licenses/MIT
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import math
18
+ import os
19
+ import random
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from sentence_transformers import SentenceTransformer
24
+
25
+ from ...base import AutoLearner
26
+
27
+
28
+ class RMSNorm(nn.Module):
29
+ """Root Mean Square normalization with learnable scale.
30
+
31
+ Computes per-position normalization:
32
+ y = weight * x / sqrt(mean(x^2) + eps)
33
+
34
+ This variant normalizes over the last dimension and keeps scale as a
35
+ learnable parameter, similar to RMSNorm used in modern transformer stacks.
36
+ """
37
+
38
+ def __init__(self, dim: int, eps: float = 1e-6):
39
+ """Initialize the RMSNorm layer.
40
+
41
+ Args:
42
+ dim: Size of the last (feature) dimension to normalize over.
43
+ eps: Small constant added inside the square root for numerical
44
+ stability.
45
+ """
46
+ super().__init__()
47
+ self.eps = eps
48
+ self.weight = nn.Parameter(torch.ones(dim))
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ """Apply RMS normalization.
52
+
53
+ Args:
54
+ x: Input tensor of shape (..., dim).
55
+
56
+ Returns:
57
+ Tensor of the same shape as `x`, RMS-normalized over the last axis.
58
+ """
59
+ rms_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
60
+ return self.weight * (x * rms_inv)
61
+
62
+
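A quick sanity check for RMSNorm (an illustrative sketch, not part of the packaged module): with the initial weight of ones, every position of the output should have a root-mean-square close to 1.

# Illustrative only: confirm that RMSNorm rescales each position to roughly unit RMS.
norm = RMSNorm(dim=4)
x = torch.randn(2, 4)
y = norm(x)
print(y.pow(2).mean(-1).sqrt())  # expected: values close to 1.0 (up to eps)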
63
+ class CrossAttentionHead(nn.Module):
64
+ """Minimal multi-head *pair* scorer using cross-attention-style projections.
65
+
66
+ Given child vector `c` and parent vector `p`:
67
+ q = W_q * c, k = W_k * p
68
+ score_head = (q_h · k_h) / sqrt(d_head)
69
+
70
+ We average the per-head scores and apply a sigmoid to produce a probability.
71
+ This is not a full attention block—just a learnable similarity function.
72
+ """
73
+
74
+ def __init__(
75
+ self, hidden_size: int, num_heads: int = 8, rms_norm_eps: float = 1e-6
76
+ ):
77
+ """Initialize projections and per-stream normalizers.
78
+
79
+ Args:
80
+ hidden_size: Dimensionality of input embeddings (child/parent).
81
+ num_heads: Number of subspaces to split the projection into.
82
+ rms_norm_eps: Epsilon for RMSNorm stability.
83
+
84
+ Raises:
85
+ AssertionError: If `hidden_size` is not divisible by `num_heads`.
86
+ """
87
+ super().__init__()
88
+ assert hidden_size % num_heads == 0, (
89
+ "hidden_size must be divisible by num_heads"
90
+ )
91
+ self.hidden_size = hidden_size
92
+ self.num_heads = num_heads
93
+ self.dim_per_head = hidden_size // num_heads
94
+
95
+ # Linear projections for queries (child) and keys (parent)
96
+ self.query_projection = nn.Linear(hidden_size, hidden_size, bias=False)
97
+ self.key_projection = nn.Linear(hidden_size, hidden_size, bias=False)
98
+
99
+ # Pre-projection normalization for stability
100
+ self.query_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
101
+ self.key_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
102
+
103
+ # Xavier init helps stabilize training
104
+ nn.init.xavier_uniform_(self.query_projection.weight)
105
+ nn.init.xavier_uniform_(self.key_projection.weight)
106
+
107
+ def forward(
108
+ self, child_embeddings: torch.Tensor, parent_embeddings: torch.Tensor
109
+ ) -> torch.Tensor:
110
+ """Score (child, parent) pairs.
111
+
112
+ Args:
113
+ child_embeddings: Tensor of shape (batch, hidden_size).
114
+ parent_embeddings: Tensor of shape (batch, hidden_size).
115
+
116
+ Returns:
117
+ Tensor of probabilities with shape (batch,), each in [0, 1].
118
+ """
119
+ batch_size, _ = child_embeddings.shape
120
+
121
+ # Project and normalize
122
+ queries = self.query_norm(self.query_projection(child_embeddings))
123
+ keys = self.key_norm(self.key_projection(parent_embeddings))
124
+
125
+ # Reshape into heads: (batch, heads, dim_per_head)
126
+ queries = queries.view(batch_size, self.num_heads, self.dim_per_head)
127
+ keys = keys.view(batch_size, self.num_heads, self.dim_per_head)
128
+
129
+ # Scaled dot-product similarity per head -> (batch, heads)
130
+ per_head_scores = (queries * keys).sum(-1) / math.sqrt(self.dim_per_head)
131
+
132
+ # Aggregate across heads -> (batch,)
133
+ mean_score = per_head_scores.mean(-1)
134
+
135
+ # Map to probability
136
+ return torch.sigmoid(mean_score)
137
+
138
+
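As a small illustration of the scorer's interface (a sketch with made-up shapes, not code shipped in the wheel), the head takes one child and one parent embedding per row and returns one probability per pair; 384 here matches the width of the default all-MiniLM-L6-v2 embedder used by the learner below.

# Illustrative only: score two (child, parent) embedding pairs.
head = CrossAttentionHead(hidden_size=384, num_heads=8)
child_vecs = torch.randn(2, 384)   # (batch, hidden_size)
parent_vecs = torch.randn(2, 384)  # (batch, hidden_size)
probs = head(child_vecs, parent_vecs)  # shape (2,), each entry in [0, 1]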
139
+ class AlexbekCrossAttnLearner(AutoLearner):
140
+ """Cross-Attention Taxonomy Learner (inherits AutoLearner).
141
+
142
+ Workflow
143
+ - Encode terms with a SentenceTransformer.
144
+ - Train a compact cross-attention head on (parent, child) pairs
145
+ (positives + sampled negatives) using BCE loss.
146
+ - Inference returns probabilities per pair; edges with prob >= 0.5 are
147
+ labeled as positive.
148
+
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ *,
154
+ embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
155
+ device: str = "cpu",
156
+ num_heads: int = 8,
157
+ lr: float = 5e-5,
158
+ weight_decay: float = 0.01,
159
+ num_epochs: int = 1,
160
+ batch_size: int = 256,
161
+ neg_ratio: float = 1.0, # negatives per positive
162
+ output_dir: str = "./results/",
163
+ seed: int = 42,
164
+ **kwargs: Any,
165
+ ):
166
+ """Configure the learner.
167
+
168
+ Args:
169
+ embedding_model: SentenceTransformer model id/path for term encoding.
170
+ device: 'cuda' or 'cpu'. If 'cuda' is requested but unavailable, CPU
171
+ is used.
172
+ num_heads: Number of heads in the cross-attention scorer.
173
+ lr: Learning rate for AdamW.
174
+ weight_decay: Weight decay for AdamW.
175
+ num_epochs: Number of epochs to train the head.
176
+ batch_size: Minibatch size for training and scoring loops.
177
+ neg_ratio: Number of sampled negatives per positive during training.
178
+ output_dir: Directory to store artifacts (reserved for future use).
179
+ seed: Random seed for reproducibility.
180
+ **kwargs: Passed through to `AutoLearner` base init.
181
+
182
+ Side Effects:
183
+ Creates `output_dir` if missing and seeds Python/Torch RNGs.
184
+ """
185
+ super().__init__(**kwargs)
186
+
187
+ # hyperparameters / settings
188
+ self.embedding_model_id = embedding_model
189
+ self.requested_device = device
190
+ self.num_heads = num_heads
191
+ self.learning_rate = lr
192
+ self.weight_decay = weight_decay
193
+ self.num_epochs = num_epochs
194
+ self.batch_size = batch_size
195
+ self.negative_ratio = neg_ratio
196
+ self.output_dir = output_dir
197
+ self.seed = seed
198
+
199
+ # Prefer the requested device, falling back to CPU when CUDA is requested but unavailable
200
+ if self.requested_device.startswith("cuda") and not torch.cuda.is_available():
201
+ self.device = torch.device("cpu")
202
+ else:
203
+ self.device = torch.device(self.requested_device)
204
+
205
+ # Will be set in load()
206
+ self.embedder: Optional[SentenceTransformer] = None
207
+ self.cross_attn_head: Optional[CrossAttentionHead] = None
208
+ self.embedding_dim: Optional[int] = None
209
+
210
+ # Cache of term -> embedding tensor (on device)
211
+ self.term_to_vector: Dict[str, torch.Tensor] = {}
212
+
213
+ os.makedirs(self.output_dir, exist_ok=True)
214
+ random.seed(self.seed)
215
+ torch.manual_seed(self.seed)
216
+
217
+ def load(self, **kwargs: Any):
218
+ """Load the sentence embedding model and initialize the cross-attention head.
219
+
220
+ Args:
221
+ **kwargs: Optional override, supports `embedding_model`.
222
+
223
+ Side Effects:
224
+ - Initializes `self.embedder` on the configured device.
225
+ - Probes and stores `self.embedding_dim`.
226
+ - Constructs `self.cross_attn_head` with the probed dimensionality.
227
+ """
228
+ model_id = kwargs.get("embedding_model", self.embedding_model_id)
229
+ self.embedder = SentenceTransformer(
230
+ model_id, trust_remote_code=True, device=str(self.device)
231
+ )
232
+
233
+ # Probe output dimensionality using a dummy encode
234
+ probe_embedding = self.embedder.encode(
235
+ ["_dim_probe_"], convert_to_tensor=True, normalize_embeddings=False
236
+ )
237
+ self.embedding_dim = int(probe_embedding.shape[-1])
238
+
239
+ # Initialize the cross-attention head
240
+ self.cross_attn_head = CrossAttentionHead(
241
+ hidden_size=self.embedding_dim, num_heads=self.num_heads
242
+ ).to(self.device)
243
+
244
+ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
245
+ """Train or infer taxonomy edges according to the AutoLearner contract.
246
+
247
+ Training (`test=False`)
248
+ - Extract positives (parent, child) and the unique term set from `data`.
249
+ - Build/extend the term embedding cache.
250
+ - Sample negatives at ratio `self.negative_ratio`.
251
+ - Train the cross-attention head with BCE loss.
252
+
253
+ Inference (`test=True`)
254
+ - Ensure embeddings exist for all terms.
255
+ - Score candidate pairs and return per-pair probabilities and labels.
256
+
257
+ Args:
258
+ data: Ontology-like object exposing `type_taxonomies.taxonomies`,
259
+ where each item has `.parent` and `.child` string-like fields.
260
+ test: If True, perform inference instead of training.
261
+
262
+ Returns:
263
+ - `None` on training.
264
+ - On inference: List of dicts
265
+ `{"parent": str, "child": str, "score": float, "label": int}`.
266
+ """
267
+ if self.embedder is None or self.cross_attn_head is None:
268
+ self.load()
269
+
270
+ if not test:
271
+ positive_pairs, unique_terms = self._extract_parent_child_pairs_and_terms(
272
+ data
273
+ )
274
+ self._ensure_term_embeddings(unique_terms)
275
+ negative_pairs = self._sample_negative_pairs(
276
+ positive_pairs, unique_terms, ratio=self.negative_ratio, seed=self.seed
277
+ )
278
+ self._train_cross_attn_head(positive_pairs, negative_pairs)
279
+ return None
280
+ else:
281
+ candidate_pairs, unique_terms = self._extract_parent_child_pairs_and_terms(
282
+ data
283
+ )
284
+ self._ensure_term_embeddings(unique_terms, append_only=True)
285
+ probabilities = self._score_parent_child_pairs(candidate_pairs)
286
+
287
+ predictions = [
288
+ {
289
+ "parent": parent,
290
+ "child": child,
291
+ "score": float(prob),
292
+ "label": int(prob >= 0.5),
293
+ }
294
+ for (parent, child), prob in zip(candidate_pairs, probabilities)
295
+ ]
296
+ return predictions
297
+
298
+ def _ensure_term_embeddings(
299
+ self, terms: List[str], append_only: bool = False
300
+ ) -> None:
301
+ """Encode terms with the sentence embedder and store in cache.
302
+
303
+ Args:
304
+ terms: List of unique term strings to embed.
305
+ append_only: If True, only embed terms missing from the cache;
306
+ otherwise (re)encode all provided terms.
307
+
308
+ Raises:
309
+ RuntimeError: If called before `load()`.
310
+ """
311
+ if self.embedder is None:
312
+ raise RuntimeError("Call load() before building term embeddings")
313
+
314
+ terms_to_encode = (
315
+ [t for t in terms if t not in self.term_to_vector] if append_only else terms
316
+ )
317
+ if not terms_to_encode:
318
+ return
319
+
320
+ embeddings = self.embedder.encode(
321
+ terms_to_encode,
322
+ convert_to_tensor=True,
323
+ normalize_embeddings=False,
324
+ batch_size=256,
325
+ show_progress_bar=False,
326
+ )
327
+ for term, embedding in zip(terms_to_encode, embeddings):
328
+ self.term_to_vector[term] = embedding.detach().to(self.device)
329
+
330
+ def _pairs_as_tensors(
331
+ self, pairs: List[Tuple[str, str]]
332
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
333
+ """Convert string pairs into aligned embedding tensors on the correct device.
334
+
335
+ Args:
336
+ pairs: List of (parent, child) term strings.
337
+
338
+ Returns:
339
+ Tuple `(child_tensor, parent_tensor)` where each tensor has shape
340
+ `(batch, embedding_dim)` and is located on `self.device`.
341
+
342
+ Notes:
343
+ This function assumes that all terms in `pairs` are present in
344
+ `self.term_to_vector`. Use `_ensure_term_embeddings` beforehand.
345
+ """
346
+ # child embeddings tensor of shape (batch, dim)
347
+ child_tensor = torch.stack(
348
+ [self.term_to_vector[child] for (_, child) in pairs], dim=0
349
+ ).to(self.device)
350
+ # parent embeddings tensor of shape (batch, dim)
351
+ parent_tensor = torch.stack(
352
+ [self.term_to_vector[parent] for (parent, _) in pairs], dim=0
353
+ ).to(self.device)
354
+ return child_tensor, parent_tensor
355
+
356
+ def _train_cross_attn_head(
357
+ self,
358
+ positive_pairs: List[Tuple[str, str]],
359
+ negative_pairs: List[Tuple[str, str]],
360
+ ) -> None:
361
+ """Train the cross-attention head with BCE loss on labeled pairs.
362
+
363
+ The dataset is a concatenation of positives (label 1) and sampled
364
+ negatives (label 0). The head is optimized with AdamW.
365
+
366
+ Args:
367
+ positive_pairs: List of ground-truth (parent, child) edges.
368
+ negative_pairs: List of sampled non-edges.
369
+
370
+ Raises:
371
+ RuntimeError: If the head has not been initialized (call `load()`).
372
+ """
373
+ if self.cross_attn_head is None:
374
+ raise RuntimeError("Head not initialized. Call load().")
375
+
376
+ self.cross_attn_head.train()
377
+ optimizer = torch.optim.AdamW(
378
+ self.cross_attn_head.parameters(),
379
+ lr=self.learning_rate,
380
+ weight_decay=self.weight_decay,
381
+ )
382
+
383
+ # Build a simple supervised dataset: 1 for positive, 0 for negative
384
+ labeled_pairs: List[Tuple[int, Tuple[str, str]]] = [
385
+ (1, pc) for pc in positive_pairs
386
+ ] + [(0, nc) for nc in negative_pairs]
387
+ random.shuffle(labeled_pairs)
388
+
389
+ def iterate_minibatches(
390
+ items: List[Tuple[int, Tuple[str, str]]], batch_size: int
391
+ ):
392
+ """Yield contiguous minibatches of size `batch_size` from `items`."""
393
+ for start in range(0, len(items), batch_size):
394
+ yield items[start : start + batch_size]
395
+
396
+ for epoch in range(self.num_epochs):
397
+ epoch_loss_sum = 0.0
398
+ for minibatch in iterate_minibatches(labeled_pairs, self.batch_size):
399
+ labels = torch.tensor(
400
+ [y for y, _ in minibatch], dtype=torch.float32, device=self.device
401
+ )
402
+ string_pairs = [pc for _, pc in minibatch]
403
+ child_tensor, parent_tensor = self._pairs_as_tensors(string_pairs)
404
+
405
+ probs = self.cross_attn_head(child_tensor, parent_tensor)
406
+ loss = F.binary_cross_entropy(probs, labels)
407
+
408
+ optimizer.zero_grad()
409
+ loss.backward()
410
+ optimizer.step()
411
+
412
+ epoch_loss_sum += float(loss.item()) * len(minibatch)
413
+
414
+ def _score_parent_child_pairs(self, pairs: List[Tuple[str, str]]) -> List[float]:
415
+ """Compute probability scores for (parent, child) pairs.
416
+
417
+ Args:
418
+ pairs: List of candidate (parent, child) edges to score.
419
+
420
+ Returns:
421
+ List of floats in [0, 1] corresponding to the input order.
422
+
423
+ Raises:
424
+ RuntimeError: If the head has not been initialized (call `load()`).
425
+ """
426
+ if self.cross_attn_head is None:
427
+ raise RuntimeError("Head not initialized. Call load().")
428
+
429
+ self.cross_attn_head.eval()
430
+ scores: List[float] = []
431
+ with torch.no_grad():
432
+ for start in range(0, len(pairs), self.batch_size):
433
+ chunk = pairs[start : start + self.batch_size]
434
+ child_tensor, parent_tensor = self._pairs_as_tensors(chunk)
435
+ prob = self.cross_attn_head(child_tensor, parent_tensor)
436
+ scores.extend(prob.detach().cpu().tolist())
437
+ return scores
438
+
439
+ def _extract_parent_child_pairs_and_terms(
440
+ self, data: Any
441
+ ) -> Tuple[List[Tuple[str, str]], List[str]]:
442
+ """Extract (parent, child) edges and the set of unique terms from an ontology-like object.
443
+
444
+ The function expects `data.type_taxonomies.taxonomies` to be an iterable
445
+ of objects with `.parent` and `.child` string-like attributes.
446
+
447
+ Args:
448
+ data: Ontology-like container.
449
+
450
+ Returns:
451
+ A tuple `(pairs, terms)` where:
452
+ - `pairs` is a list of (parent, child) strings,
453
+ - `terms` is a sorted list of unique term strings (parents ∪ children).
454
+ """
455
+ parent_child_pairs: List[Tuple[str, str]] = []
456
+ unique_terms = set()
457
+ for edge in getattr(data, "type_taxonomies").taxonomies:
458
+ parent, child = str(edge.parent), str(edge.child)
459
+ parent_child_pairs.append((parent, child))
460
+ unique_terms.add(parent)
461
+ unique_terms.add(child)
462
+ return parent_child_pairs, sorted(unique_terms)
463
+
464
+ def _sample_negative_pairs(
465
+ self,
466
+ positive_pairs: List[Tuple[str, str]],
467
+ terms: List[str],
468
+ ratio: float = 1.0,
469
+ seed: int = 42,
470
+ ) -> List[Tuple[str, str]]:
471
+ """Sample random negative (parent, child) pairs not present in positives.
472
+
473
+ Sampling is uniform over the Cartesian product of `terms` excluding
474
+ (x, x) self-pairs and any pair found in `positive_pairs`.
475
+
476
+ Args:
477
+ positive_pairs: Known positive edges to exclude.
478
+ terms: Candidate vocabulary (parents ∪ children).
479
+ ratio: Number of negatives per positive to draw.
480
+ seed: RNG seed used for reproducible sampling.
481
+
482
+ Returns:
483
+ A list of sampled negative pairs of length at most
484
+ `int(len(positive_pairs) * ratio)`; fewer are returned if the
+ vocabulary cannot supply enough distinct negatives.
485
+ """
486
+ random.seed(seed)
487
+ term_list = list(terms)
488
+ positive_set = set(positive_pairs)
489
+ negatives: List[Tuple[str, str]] = []
490
+ target_negative_count = int(len(positive_pairs) * ratio)
+ # Bound the number of draws so a tiny vocabulary cannot stall the loop forever
+ max_attempts = max(100, target_negative_count * 20)
+ attempts = 0
491
+ while len(negatives) < target_negative_count and attempts < max_attempts:
+ attempts += 1
492
+ parent = random.choice(term_list)
493
+ child = random.choice(term_list)
494
+ if parent == child:
495
+ continue
496
+ candidate = (parent, child)
497
+ if candidate in positive_set:
498
+ continue
499
+ negatives.append(candidate)
500
+ return negatives
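
The public AutoLearner entry points are not part of this diff, so the sketch below drives the learner directly through the load() and _taxonomy_discovery() hooks shown above. The SimpleNamespace stand-in for the ontology object and the constructor call without base-class arguments are assumptions made for illustration only.

from types import SimpleNamespace

# Minimal end-to-end sketch under the assumptions stated above.
edges = [
    SimpleNamespace(parent="animal", child="dog"),
    SimpleNamespace(parent="animal", child="cat"),
    SimpleNamespace(parent="plant", child="tree"),
]
ontology = SimpleNamespace(type_taxonomies=SimpleNamespace(taxonomies=edges))

learner = AlexbekCrossAttnLearner(device="cpu", num_epochs=1, batch_size=8)
learner.load()                                      # embedder + cross-attention head
learner._taxonomy_discovery(ontology, test=False)   # train on positives + sampled negatives
predictions = learner._taxonomy_discovery(ontology, test=True)
# predictions: [{"parent": ..., "child": ..., "score": float, "label": 0 or 1}, ...]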