OntoLearner 1.4.11-py3-none-any.whl → 1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +4 -2
- ontolearner/learner/__init__.py +2 -1
- ontolearner/learner/label_mapper.py +4 -3
- ontolearner/learner/llm.py +257 -0
- ontolearner/learner/taxonomy_discovery/alexbek.py +632 -310
- ontolearner/learner/taxonomy_discovery/skhnlp.py +216 -156
- ontolearner/ontology/biology.py +2 -3
- ontolearner/ontology/chemistry.py +16 -18
- ontolearner/ontology/ecology_environment.py +2 -3
- ontolearner/ontology/general.py +4 -6
- ontolearner/ontology/material_science_engineering.py +64 -45
- ontolearner/ontology/medicine.py +2 -3
- ontolearner/ontology/scholarly_knowledge.py +6 -9
- ontolearner/processor.py +3 -3
- {ontolearner-1.4.11.dist-info → ontolearner-1.5.0.dist-info}/METADATA +1 -1
- {ontolearner-1.4.11.dist-info → ontolearner-1.5.0.dist-info}/RECORD +19 -19
- {ontolearner-1.4.11.dist-info → ontolearner-1.5.0.dist-info}/WHEEL +1 -1
- {ontolearner-1.4.11.dist-info → ontolearner-1.5.0.dist-info}/licenses/LICENSE +0 -0
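The bulk of this release is the rewritten cross-attention taxonomy learner in ontolearner/learner/taxonomy_discovery/alexbek.py, shown in the diff below: 1.5.0 adds dropout, mixed-precision training, a validation split with early stopping, checkpointing (save_model / load_model / save_config / from_config), and optional F1-based threshold tuning. The snippet below is a minimal usage sketch based only on the signatures visible in this diff; the import path and the assumption that the AutoLearner base needs no extra constructor arguments are not confirmed by the diff, and the commented-out training call uses the internal _taxonomy_discovery hook because no public fit/predict wrapper appears here.

    import torch
    from ontolearner.learner.taxonomy_discovery.alexbek import AlexbekCrossAttnLearner  # assumed path

    # Configure the learner with the new 1.5.0 options (all keyword-only).
    learner = AlexbekCrossAttnLearner(
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        device="cuda" if torch.cuda.is_available() else "cpu",
        num_epochs=10,
        use_amp=True,                    # mixed-precision training (new in 1.5.0)
        optimize_threshold_on_val=True,  # replicate the "Validation-F1" thresholding
        output_dir="./results/",
    )
    learner.load()  # loads the SentenceTransformer and builds the cross-attention head

    # Training writes ./results/best_model.pt and ./results/config.json; `ontology_data`
    # is a placeholder for an OntoLearner ontology exposing type_taxonomies.taxonomies.
    # learner._taxonomy_discovery(ontology_data, test=False)
    # predictions = learner._taxonomy_discovery(ontology_data, test=True)

    # A trained head can later be restored without retraining:
    restored = AlexbekCrossAttnLearner(output_dir="./results/")
    restored.load_model("./results/best_model.pt")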
ontolearner/learner/taxonomy_discovery/alexbek.py

@@ -11,365 +11,524 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Any, Dict, List, Optional, Tuple
-
+import json
 import math
 import os
 import random
+from datetime import datetime
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
+from torch.cuda.amp import GradScaler

 from ...base import AutoLearner


 class RMSNorm(nn.Module):
-    """Root Mean Square normalization with learnable scale.
-
-    Computes per-position normalization:
-        y = weight * x / sqrt(mean(x^2) + eps)
-
-    This variant normalizes over the last dimension and keeps scale as a
-    learnable parameter, similar to RMSNorm used in modern transformer stacks.
-    """
-
+    """Root Mean Square normalization with learnable scale."""
     def __init__(self, dim: int, eps: float = 1e-6):
-        """Initialize the RMSNorm layer.
-
-        Args:
-            dim: Size of the last (feature) dimension to normalize over.
-            eps: Small constant added inside the square root for numerical
-                stability.
-        """
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Apply RMS normalization.
-
-        Args:
-            x: Input tensor of shape (..., dim).
-
-        Returns:
-            Tensor of the same shape as `x`, RMS-normalized over the last axis.
-        """
         rms_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
         return self.weight * (x * rms_inv)

-
 class CrossAttentionHead(nn.Module):
-    """
-
-    Given child vector `c` and parent vector `p`:
-        q = W_q * c, k = W_k * p
-        score_head = (q_h · k_h) / sqrt(d_head)
-
-    We average the per-head scores and apply a sigmoid to produce a probability.
-    This is not a full attention block—just a learnable similarity function.
-    """
-
-    def __init__(
-        self, hidden_size: int, num_heads: int = 8, rms_norm_eps: float = 1e-6
-    ):
-        """Initialize projections and per-stream normalizers.
-
-        Args:
-            hidden_size: Dimensionality of input embeddings (child/parent).
-            num_heads: Number of subspaces to split the projection into.
-            rms_norm_eps: Epsilon for RMSNorm stability.
-
-        Raises:
-            AssertionError: If `hidden_size` is not divisible by `num_heads`.
-        """
+    """Efficient multi-head cross-attention scorer for parent-child pairs."""
+    def __init__(self, hidden_size: int, num_heads: int = 8, rms_norm_eps: float = 1e-6, dropout: float = 0.1):
         super().__init__()
-        assert hidden_size % num_heads == 0, (
-            "hidden_size must be divisible by num_heads"
-        )
+        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.dim_per_head = hidden_size // num_heads

-        # Linear projections for queries (child) and keys (parent)
         self.query_projection = nn.Linear(hidden_size, hidden_size, bias=False)
         self.key_projection = nn.Linear(hidden_size, hidden_size, bias=False)

-        # Pre-projection normalization for stability
         self.query_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
         self.key_norm = RMSNorm(hidden_size, eps=rms_norm_eps)

-
+        self.dropout = nn.Dropout(dropout)
+
         nn.init.xavier_uniform_(self.query_projection.weight)
         nn.init.xavier_uniform_(self.key_projection.weight)

-    def forward(
-
-
-        """Score (child, parent) pairs.
+    def forward(self, child_embeddings: torch.Tensor, parent_embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        Score (child, parent) pairs efficiently.

         Args:
-            child_embeddings:
-            parent_embeddings:
+            child_embeddings: (batch_child, hidden_size) or (1, n_terms, hidden_size) for broadcasting
+            parent_embeddings: (batch_parent, hidden_size) or (1, n_terms, hidden_size) for broadcasting

         Returns:
-
+            scores: (batch_child, batch_parent) if both are 2D, or appropriate broadcast shape
         """
-
+        # Handle 2D input (standard batch processing)
+        if child_embeddings.dim() == 2 and parent_embeddings.dim() == 2:
+            batch_size = child_embeddings.shape[0]
+            queries = self.query_norm(self.query_projection(child_embeddings))
+            keys = self.key_norm(self.key_projection(parent_embeddings))

-
-
-        keys = self.key_norm(self.key_projection(parent_embeddings))
+            queries = self.dropout(queries)
+            keys = self.dropout(keys)

-
-
-        keys = keys.view(batch_size, self.num_heads, self.dim_per_head)
+            queries = queries.view(batch_size, self.num_heads, self.dim_per_head)
+            keys = keys.view(batch_size, self.num_heads, self.dim_per_head)

-
-
+            per_head_scores = (queries * keys).sum(-1) / math.sqrt(self.dim_per_head)
+            mean_score = per_head_scores.mean(-1)
+            return torch.sigmoid(mean_score)

-        #
-
+        # Handle 3D input for efficient matrix computation
+        elif child_embeddings.dim() == 3 and parent_embeddings.dim() == 3:
+            n_child = child_embeddings.shape[1]
+            n_parent = parent_embeddings.shape[1]

-
-
+            queries = self.query_norm(self.query_projection(child_embeddings))
+            keys = self.key_norm(self.key_projection(parent_embeddings))

+            queries = self.dropout(queries)
+            keys = self.dropout(keys)

-
-
+            queries = queries.view(1, n_child, self.num_heads, self.dim_per_head)
+            keys = keys.view(1, n_parent, self.num_heads, self.dim_per_head)

-
-
-    - Train a compact cross-attention head on (parent, child) pairs
-      (positives + sampled negatives) using BCE loss.
-    - Inference returns probabilities per pair; edges with prob >= 0.5 are
-      labeled as positive.
+            queries = queries.squeeze(0).transpose(1, 2)
+            keys = keys.squeeze(0).transpose(1, 2)

+            per_head_scores = torch.einsum('chd,phd->cph', queries, keys) / math.sqrt(self.dim_per_head)
+            mean_score = per_head_scores.mean(-1)
+            return torch.sigmoid(mean_score)
+
+
+class AlexbekCrossAttnLearner(AutoLearner):
     """
+    Cross-Attention Taxonomy Learner - faithful reproduction of Alexbek's approach.

+    This implementation follows the original paper's methodology:
+    - Computes full NxN pairwise scores for all term pairs
+    - Uses threshold-based prediction (0.5 default, or F1-optimized on validation)
+    - No candidate pre-filtering (can be optionally enabled for large taxonomies)
+    """
     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        *,
+        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
+        device: str = "cuda" if torch.cuda.is_available() else "cpu",
+        num_heads: int = 8,
+        dropout: float = 0.1,
+        lr: float = 1e-4,
+        weight_decay: float = 0.01,
+        num_epochs: int = 10,
+        batch_size: int = 256,
+        inference_batch_size: int = 512,
+        neg_ratio: float = 1.0,
+        top_k_candidates: Optional[int] = None,  # None = original behavior (all pairs)
+        output_dir: str = "./results/",
+        seed: int = 42,
+        cache_embeddings: bool = True,
+        use_lr_scheduler: bool = True,
+        warmup_epochs: int = 1,
+        gradient_clip: float = 1.0,
+        use_amp: bool = True,
+        hard_negative_ratio: float = 0.0,  # 0.0 = original (all random negatives)
+        patience: int = 3,
+        validation_split: float = 0.1,
+        normalize_embeddings: bool = True,
+        prediction_threshold: float = 0.5,  # Original uses 0.5 or F1-optimized
+        optimize_threshold_on_val: bool = True,  # Set True to replicate "Validation-F1" approach
+        **kwargs: Any,
     ):
-        """Configure the learner.
-
-        Args:
-            embedding_model: SentenceTransformer model id/path for term encoding.
-            device: 'cuda' or 'cpu'. If 'cuda' is requested but unavailable, CPU
-                is used.
-            num_heads: Number of heads in the cross-attention scorer.
-            lr: Learning rate for AdamW.
-            weight_decay: Weight decay for AdamW.
-            num_epochs: Number of epochs to train the head.
-            batch_size: Minibatch size for training and scoring loops.
-            neg_ratio: Number of sampled negatives per positive during training.
-            output_dir: Directory to store artifacts (reserved for future use).
-            seed: Random seed for reproducibility.
-            **kwargs: Passed through to `AutoLearner` base init.
-
-        Side Effects:
-            Creates `output_dir` if missing and seeds Python/Torch RNGs.
-        """
         super().__init__(**kwargs)

-        # hyperparameters / settings
         self.embedding_model_id = embedding_model
         self.requested_device = device
         self.num_heads = num_heads
+        self.dropout = dropout
         self.learning_rate = lr
         self.weight_decay = weight_decay
         self.num_epochs = num_epochs
         self.batch_size = batch_size
+        self.inference_batch_size = inference_batch_size
         self.negative_ratio = neg_ratio
+        self.top_k_candidates = top_k_candidates
         self.output_dir = output_dir
         self.seed = seed
+        self.cache_embeddings = cache_embeddings
+        self.use_lr_scheduler = use_lr_scheduler
+        self.warmup_epochs = warmup_epochs
+        self.gradient_clip = gradient_clip
+        self.use_amp = use_amp and torch.cuda.is_available()
+        self.hard_negative_ratio = hard_negative_ratio
+        self.patience = patience
+        self.validation_split = validation_split
+        self.normalize_embeddings = normalize_embeddings
+        self.prediction_threshold = prediction_threshold
+        self.optimize_threshold_on_val = optimize_threshold_on_val

-        # Prefer requested device but gracefully fall back to CPU
         if torch.cuda.is_available() or self.requested_device == "cpu":
             self.device = torch.device(self.requested_device)
         else:
             self.device = torch.device("cpu")

-        # Will be set in load()
         self.embedder: Optional[SentenceTransformer] = None
         self.cross_attn_head: Optional[CrossAttentionHead] = None
         self.embedding_dim: Optional[int] = None
-
-        # Cache of term -> embedding tensor (on device)
         self.term_to_vector: Dict[str, torch.Tensor] = {}
+        self.scaler: Optional[GradScaler] = GradScaler() if self.use_amp else None
+        self.best_threshold: float = self.prediction_threshold

         os.makedirs(self.output_dir, exist_ok=True)
         random.seed(self.seed)
         torch.manual_seed(self.seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(self.seed)

     def load(self, **kwargs: Any):
-        """Load the sentence embedding model and initialize the cross-attention head.
-
-        Args:
-            **kwargs: Optional override, supports `embedding_model`.
-
-        Side Effects:
-            - Initializes `self.embedder` on the configured device.
-            - Probes and stores `self.embedding_dim`.
-            - Constructs `self.cross_attn_head` with the probed dimensionality.
-        """
+        """Load the sentence embedding model and initialize the cross-attention head."""
         model_id = kwargs.get("embedding_model", self.embedding_model_id)
-        self.embedder = SentenceTransformer(
-            model_id, trust_remote_code=True, device=str(self.device)
-        )
+        self.embedder = SentenceTransformer(model_id, trust_remote_code=True, device=str(self.device))

-
-
-
-        )
+        probe_embedding = self.embedder.encode(["_dim_probe_"],
+                                               convert_to_tensor=True,
+                                               normalize_embeddings=False)
         self.embedding_dim = int(probe_embedding.shape[-1])

-        # Initialize the cross-attention head
         self.cross_attn_head = CrossAttentionHead(
-            hidden_size=self.embedding_dim,
+            hidden_size=self.embedding_dim,
+            num_heads=self.num_heads,
+            dropout=self.dropout
         ).to(self.device)

-    def
-        """
+    def save_model(self, path: str) -> None:
+        """Save the trained cross-attention head."""
+        if self.cross_attn_head is None:
+            raise RuntimeError("No model to save")
+
+        checkpoint = {
+            'model_state_dict': self.cross_attn_head.state_dict(),
+            'embedding_dim': self.embedding_dim,
+            'num_heads': self.num_heads,
+            'dropout': self.dropout,
+            'embedding_model_id': self.embedding_model_id,
+            'best_threshold': self.best_threshold,
+        }
+
+        torch.save(checkpoint, path)
+        print(f"Model saved to {path}")
+
+    def load_model(self, path: str) -> None:
+        """Load a trained cross-attention head."""
+        checkpoint = torch.load(path, map_location=self.device)
+
+        self.embedding_dim = checkpoint['embedding_dim']
+        self.num_heads = checkpoint['num_heads']
+        self.dropout = checkpoint.get('dropout', 0.1)
+        self.embedding_model_id = checkpoint.get('embedding_model_id', self.embedding_model_id)
+        self.best_threshold = checkpoint.get('best_threshold', 0.5)
+
+        # Load embedder if not already loaded
+        if self.embedder is None:
+            self.embedder = SentenceTransformer(
+                self.embedding_model_id,
+                trust_remote_code=True,
+                device=str(self.device)
+            )

-
-
-
-
-
+        self.cross_attn_head = CrossAttentionHead(
+            hidden_size=self.embedding_dim,
+            num_heads=self.num_heads,
+            dropout=self.dropout
+        ).to(self.device)

-
-
-
+        self.cross_attn_head.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Model loaded from {path} (threshold: {self.best_threshold:.3f})")
+
+    def save_config(self, path: str) -> None:
+        """Save hyperparameters to JSON."""
+        config = {
+            'embedding_model': self.embedding_model_id,
+            'num_heads': self.num_heads,
+            'dropout': self.dropout,
+            'lr': self.learning_rate,
+            'weight_decay': self.weight_decay,
+            'num_epochs': self.num_epochs,
+            'batch_size': self.batch_size,
+            'inference_batch_size': self.inference_batch_size,
+            'negative_ratio': self.negative_ratio,
+            'top_k_candidates': self.top_k_candidates,
+            'use_lr_scheduler': self.use_lr_scheduler,
+            'warmup_epochs': self.warmup_epochs,
+            'gradient_clip': self.gradient_clip,
+            'use_amp': self.use_amp,
+            'hard_negative_ratio': self.hard_negative_ratio,
+            'patience': self.patience,
+            'validation_split': self.validation_split,
+            'normalize_embeddings': self.normalize_embeddings,
+            'prediction_threshold': self.prediction_threshold,
+            'optimize_threshold_on_val': self.optimize_threshold_on_val,
+            'best_threshold': self.best_threshold,
+            'seed': self.seed,
+        }
+
+        with open(path, 'w') as f:
+            json.dump(config, f, indent=2)
+        print(f"Configuration saved to {path}")
+
+    @classmethod
+    def from_config(cls, config_path: str, **override_kwargs):
+        """Load from configuration file."""
+        with open(config_path, 'r') as f:
+            config = json.load(f)
+
+        config.update(override_kwargs)
+        return cls(**config)

-
-
-
-            test: If True, perform inference instead of training.
+    def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
+        """
+        Train or infer taxonomy edges.

-
-
-            - On inference: List of dicts
-              `{"parent": str, "child": str, "score": float, "label": int}`.
+        Original behavior: Scores ALL possible pairs and applies threshold.
+        Optional optimization: Can pre-filter to top-k candidates if top_k_candidates is set.
         """
         if self.embedder is None or self.cross_attn_head is None:
             self.load()

         if not test:
-
-
-            )
+            # Training mode
+            positive_pairs, unique_terms = self._extract_parent_child_pairs_and_terms(data, test=test)
             self._ensure_term_embeddings(unique_terms)
-            negative_pairs = self._sample_negative_pairs(
-
-
+            negative_pairs = self._sample_negative_pairs(positive_pairs,
+                                                         unique_terms,
+                                                         ratio=self.negative_ratio,
+                                                         seed=self.seed)
             self._train_cross_attn_head(positive_pairs, negative_pairs)
+
+            # Save model and config after training
+            model_path = os.path.join(self.output_dir, "best_model.pt")
+            config_path = os.path.join(self.output_dir, "config.json")
+            self.save_model(model_path)
+            self.save_config(config_path)
+
             return None
         else:
-
-
-            )
+            # Inference mode
+            candidate_pairs, unique_terms = self._extract_parent_child_pairs_and_terms(data, test=test)
             self._ensure_term_embeddings(unique_terms, append_only=True)
-            probabilities = self._score_parent_child_pairs(candidate_pairs)
-
-            predictions = [
-                {
-                    "parent": parent,
-                    "child": child,
-                    "score": float(prob),
-                    "label": int(prob >= 0.5),
-                }
-                for (parent, child), prob in zip(candidate_pairs, probabilities)
-            ]
-            return predictions

-
-
-
-
+            # Original approach: score all pairs
+            if self.top_k_candidates is None:
+                print(f"ORIGINAL MODE: Computing full {len(unique_terms)}x{len(unique_terms)} probability matrix...")
+                probabilities = self._score_all_pairs_efficient(unique_terms)
+
+                # Apply threshold to get predictions
+                print(f"Applying threshold {self.best_threshold:.3f} to extract predictions...")
+                binary_matrix = (probabilities >= self.best_threshold).cpu().numpy()
+
+                # Find indices where prediction is 1
+                child_indices, parent_indices = binary_matrix.nonzero()
+
+                # Get corresponding probabilities
+                probs = probabilities[child_indices, parent_indices].cpu().numpy()
+
+                # Build predictions
+                predictions = [
+                    {
+                        "parent": unique_terms[parent_idx],
+                        "child": unique_terms[child_idx],
+                        "score": float(prob),
+                        "label": 1,
+                    }
+                    for child_idx, parent_idx, prob in zip(child_indices, parent_indices, probs)
+                    if child_idx != parent_idx  # Exclude self-loops
+                ]
+
+                print(
+                    f"Found {len(predictions)} positive predictions from {len(unique_terms) ** 2 - len(unique_terms)} possible pairs")
+
+            else:
+                # Optional optimization: pre-filter candidates
+                print(f"OPTIMIZATION MODE: Filtering to top-{self.top_k_candidates} candidates per term...")
+                print("WARNING: This is NOT the original Alexbek approach but an efficiency optimization.")
+                candidate_pairs = self._filter_top_k_candidates(unique_terms, self.top_k_candidates)
+                print(f"Reduced to {len(candidate_pairs)} candidate pairs")
+
+                # Score filtered candidates
+                print("Scoring filtered candidate pairs...")
+                probabilities = self._score_specific_pairs(candidate_pairs)
+
+                # Apply threshold
+                predictions = [
+                    {
+                        "parent": parent,
+                        "child": child,
+                        "score": float(prob),
+                        "label": 1,
+                    }
+                    for (parent, child), prob in zip(candidate_pairs, probabilities)
+                    if prob >= self.best_threshold and parent != child
+                ]
+
+                print(f"Found {len(predictions)} positive predictions from {len(candidate_pairs)} candidate pairs")

-
-            terms: List of unique term strings to embed.
-            append_only: If True, only embed terms missing from the cache;
-                otherwise (re)encode all provided terms.
+            return predictions

-
-
-        """
+    def _ensure_term_embeddings(self, terms: List[str], append_only: bool = False) -> None:
+        """Encode terms efficiently with batching."""
         if self.embedder is None:
             raise RuntimeError("Call load() before building term embeddings")

-        terms_to_encode = (
-            [t for t in terms if t not in self.term_to_vector] if append_only else terms
-        )
+        terms_to_encode = ([t for t in terms if t not in self.term_to_vector] if append_only else terms)
         if not terms_to_encode:
             return

+        # Batch encode terms with normalization
         embeddings = self.embedder.encode(
             terms_to_encode,
             convert_to_tensor=True,
-            normalize_embeddings=
-            batch_size=
-            show_progress_bar=
+            normalize_embeddings=self.normalize_embeddings,
+            batch_size=self.inference_batch_size,
+            show_progress_bar=True,
         )
+
         for term, embedding in zip(terms_to_encode, embeddings):
-            self.term_to_vector[term] = embedding.
+            self.term_to_vector[term] = embedding.to(self.device)

-    def
-
-
-        """Convert string pairs into aligned embedding tensors on the correct device.
+    def _score_specific_pairs(self, pairs: List[Tuple[str, str]]) -> List[float]:
+        """
+        Score only specific (parent, child) pairs efficiently in batches.

         Args:
-            pairs: List of (parent, child)
+            pairs: List of (parent, child) tuples to score

         Returns:
-
-
+            List of probability scores corresponding to input pairs
+        """
+        if self.cross_attn_head is None:
+            raise RuntimeError("Head not initialized. Call load().")
+
+        self.cross_attn_head.eval()
+        scores: List[float] = []
+
+        with torch.no_grad():
+            for start in tqdm(range(0, len(pairs), self.inference_batch_size), desc="Scoring pairs"):
+                chunk = pairs[start: start + self.inference_batch_size]
+                child_tensor, parent_tensor = self._pairs_as_tensors(chunk)
+                prob = self.cross_attn_head(child_tensor, parent_tensor)
+                scores.extend(prob.detach().cpu().tolist())

-
-
-
+        return scores
+
+    def _score_all_pairs_efficient(self, terms: List[str]) -> torch.Tensor:
         """
-
-        child_tensor = torch.stack(
-            [self.term_to_vector[child] for (_, child) in pairs], dim=0
-        ).to(self.device)
-        # parent embeddings tensor of shape (batch, dim)
-        parent_tensor = torch.stack(
-            [self.term_to_vector[parent] for (parent, _) in pairs], dim=0
-        ).to(self.device)
-        return child_tensor, parent_tensor
+        Efficiently score all pairs using matrix operations (ORIGINAL ALEXBEK APPROACH).

-
-
-
-
-
-        """Train the cross-attention head with BCE loss on labeled pairs.
+        Returns:
+            scores: (n_terms, n_terms) matrix where scores[i,j] = P(terms[j] is parent of terms[i])
+        """
+        if self.cross_attn_head is None:
+            raise RuntimeError("Head not initialized. Call load().")

-
-
+        self.cross_attn_head.eval()
+        n_terms = len(terms)

-
-
-
+        # Stack all embeddings
+        all_embeddings = torch.stack([self.term_to_vector[t] for t in terms], dim=0)
+
+        # Compute scores in chunks to manage memory
+        scores_matrix = torch.zeros((n_terms, n_terms), device=self.device)
+
+        with torch.no_grad():
+            chunk_size = self.inference_batch_size
+
+            progress_bar = tqdm(
+                range(0, n_terms, chunk_size),
+                desc="Scoring all pairs",
+                total=(n_terms + chunk_size - 1) // chunk_size
+            )
+
+            for i in progress_bar:
+                end_i = min(i + chunk_size, n_terms)
+                child_chunk = all_embeddings[i:end_i]
+
+                # Score against all parents at once
+                child_broadcast = child_chunk.unsqueeze(0)
+                parent_broadcast = all_embeddings.unsqueeze(0)
+
+                chunk_scores = self.cross_attn_head(child_broadcast, parent_broadcast)
+                scores_matrix[i:end_i, :] = chunk_scores

-
-
+                progress_bar.set_postfix({
+                    'completed': f'{end_i}/{n_terms}',
+                    'pairs_scored': f'{end_i * n_terms:,}'
+                })
+
+        return scores_matrix
+
+    def _pairs_as_tensors(self, pairs: List[Tuple[str, str]]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Convert string pairs into aligned embedding tensors."""
+        child_tensor = torch.stack([self.term_to_vector[child] for (_, child) in pairs], dim=0)
+        parent_tensor = torch.stack([self.term_to_vector[parent] for (parent, _) in pairs], dim=0)
+        return child_tensor, parent_tensor
+
+    def _optimize_threshold_on_validation(self, val_pairs: List[Tuple[int, Tuple[str, str]]]) -> float:
+        """
+        Find optimal threshold that maximizes F1 on validation set.
+        This replicates the "Validation-F1" approach from the paper.
         """
+        if not val_pairs:
+            return self.prediction_threshold
+
+        print("Optimizing prediction threshold on validation set...")
+        self.cross_attn_head.eval()
+
+        # Get validation labels and scores
+        val_labels = []
+        val_scores = []
+
+        with torch.no_grad():
+            for label, (parent, child) in val_pairs:
+                val_labels.append(label)
+                child_tensor = self.term_to_vector[child].unsqueeze(0)
+                parent_tensor = self.term_to_vector[parent].unsqueeze(0)
+                score = self.cross_attn_head(child_tensor, parent_tensor).item()
+                val_scores.append(score)
+
+        val_labels = torch.tensor(val_labels)
+        val_scores = torch.tensor(val_scores)
+
+        # Try different thresholds
+        best_f1 = 0.0
+        best_threshold = 0.5
+
+        for threshold in torch.linspace(0.1, 0.9, 50):
+            preds = (val_scores >= threshold).long()
+
+            tp = ((preds == 1) & (val_labels == 1)).sum().item()
+            fp = ((preds == 1) & (val_labels == 0)).sum().item()
+            fn = ((preds == 0) & (val_labels == 1)).sum().item()
+
+            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+
+            if f1 > best_f1:
+                best_f1 = f1
+                best_threshold = threshold.item()
+
+        print(f"Optimal threshold: {best_threshold:.3f} (F1: {best_f1:.4f})")
+        return best_threshold
+
+    def _train_cross_attn_head(self,
+                               positive_pairs: List[Tuple[str, str]],
+                               negative_pairs: List[Tuple[str, str]]) -> None:
+        """Train the cross-attention head with BCE loss, validation, and early stopping."""
         if self.cross_attn_head is None:
             raise RuntimeError("Head not initialized. Call load().")

@@ -380,121 +539,284 @@ class AlexbekCrossAttnLearner(AutoLearner):
             weight_decay=self.weight_decay,
         )

-        #
-        labeled_pairs: List[Tuple[int, Tuple[str, str]]] = [
-
-        ] + [(0, nc) for nc in negative_pairs]
+        # Prepare labeled pairs and split into train/val
+        labeled_pairs: List[Tuple[int, Tuple[str, str]]] = [(1, pc) for pc in positive_pairs] + \
+                                                           [(0, nc) for nc in negative_pairs]
         random.shuffle(labeled_pairs)

-
-
-
-
-
-
+        split_idx = int((1 - self.validation_split) * len(labeled_pairs))
+        train_pairs = labeled_pairs[:split_idx]
+        val_pairs = labeled_pairs[split_idx:]
+
+        print(f"Training samples: {len(train_pairs)}, Validation samples: {len(val_pairs)}")
+
+        # Setup learning rate scheduler
+        scheduler = None
+        if self.use_lr_scheduler:
+            total_steps = (len(train_pairs) // self.batch_size + 1) * self.num_epochs
+            warmup_steps = (len(train_pairs) // self.batch_size + 1) * self.warmup_epochs
+
+            def lr_lambda(step):
+                if step < warmup_steps:
+                    return (step + 1) / (warmup_steps + 1)
+                else:
+                    progress = (step - warmup_steps) / (total_steps - warmup_steps)
+                    return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))
+
+            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

-
+        def iterate_minibatches(items: List[Tuple[int, Tuple[str, str]]], batch_size: int):
+            for start in range(0, len(items), batch_size):
+                yield items[start: start + batch_size]
+
+        # Training loop with early stopping
+        best_val_loss = float('inf')
+        best_model_state = None
+        patience_counter = 0
+        metrics_history = []
+        global_step = 0
+
+        for epoch in tqdm(range(self.num_epochs), desc="Training"):
+            # Training phase
+            self.cross_attn_head.train()
             epoch_loss_sum = 0.0
-
-
-
-                )
+
+            for minibatch in iterate_minibatches(train_pairs, self.batch_size):
+                labels = torch.tensor([y for y, _ in minibatch], dtype=torch.float32, device=self.device)
                 string_pairs = [pc for _, pc in minibatch]
                 child_tensor, parent_tensor = self._pairs_as_tensors(string_pairs)

-                probs = self.cross_attn_head(child_tensor, parent_tensor)
-                loss = F.binary_cross_entropy(probs, labels)
-
                 optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()

-
+                # Mixed precision training
+                if self.use_amp:
+                    probs = self.cross_attn_head(child_tensor, parent_tensor)
+                    loss = F.binary_cross_entropy(probs, labels)

-
-        """Compute probability scores for (parent, child) pairs.
+                    self.scaler.scale(loss).backward()

-
-
+                    if self.gradient_clip > 0:
+                        self.scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.cross_attn_head.parameters(),
+                            self.gradient_clip
+                        )

-
-
+                    self.scaler.step(optimizer)
+                    self.scaler.update()
+                else:
+                    probs = self.cross_attn_head(child_tensor, parent_tensor)
+                    loss = F.binary_cross_entropy(probs, labels)

-
-            RuntimeError: If the head has not been initialized (call `load()`).
-        """
-        if self.cross_attn_head is None:
-            raise RuntimeError("Head not initialized. Call load().")
+                    loss.backward()

-
-
-
-
-
-
-
-            scores.extend(prob.detach().cpu().tolist())
-        return scores
+                    if self.gradient_clip > 0:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.cross_attn_head.parameters(),
+                            self.gradient_clip
+                        )
+
+                    optimizer.step()

-
-
-
-
+                if scheduler is not None:
+                    scheduler.step()
+                global_step += 1
+
+                epoch_loss_sum += float(loss.item()) * len(minibatch)

-
-
+            avg_train_loss = epoch_loss_sum / len(train_pairs)
+
+            # Validation phase
+            self.cross_attn_head.eval()
+            val_loss_sum = 0.0
+
+            with torch.no_grad():
+                for minibatch in iterate_minibatches(val_pairs, self.batch_size):
+                    labels = torch.tensor([y for y, _ in minibatch], dtype=torch.float32, device=self.device)
+                    string_pairs = [pc for _, pc in minibatch]
+                    child_tensor, parent_tensor = self._pairs_as_tensors(string_pairs)
+
+                    probs = self.cross_attn_head(child_tensor, parent_tensor)
+                    loss = F.binary_cross_entropy(probs, labels)
+                    val_loss_sum += float(loss.item()) * len(minibatch)
+
+            avg_val_loss = val_loss_sum / len(val_pairs)
+
+            # Track metrics
+            current_lr = optimizer.param_groups[0]['lr']
+            metrics = {
+                'epoch': epoch + 1,
+                'train_loss': avg_train_loss,
+                'val_loss': avg_val_loss,
+                'learning_rate': current_lr,
+                'timestamp': datetime.now().isoformat()
+            }
+            metrics_history.append(metrics)
+
+            # Save best model
+            if avg_val_loss < best_val_loss:
+                best_val_loss = avg_val_loss
+                best_model_state = self.cross_attn_head.state_dict()
+                patience_counter = 0
+                print(f"Epoch {epoch + 1}/{self.num_epochs} | "
+                      f"Train Loss: {avg_train_loss:.4f} | "
+                      f"Val Loss: {avg_val_loss:.4f} ⭐ (Best) | "
+                      f"LR: {current_lr:.6f}")
+            else:
+                patience_counter += 1
+                print(f"Epoch {epoch + 1}/{self.num_epochs} | "
+                      f"Train Loss: {avg_train_loss:.4f} | "
+                      f"Val Loss: {avg_val_loss:.4f} | "
+                      f"LR: {current_lr:.6f} | "
+                      f"Patience: {patience_counter}/{self.patience}")
+
+            # Early stopping
+            if patience_counter >= self.patience:
+                print(f"Early stopping triggered at epoch {epoch + 1}")
+                break
+
+        # Restore best model
+        if best_model_state is not None:
+            self.cross_attn_head.load_state_dict(best_model_state)
+            print(f"Restored best model with validation loss: {best_val_loss:.4f}")
+
+        # Optimize threshold on validation set if requested
+        if self.optimize_threshold_on_val and val_pairs:
+            self.best_threshold = self._optimize_threshold_on_validation(val_pairs)
+        else:
+            self.best_threshold = self.prediction_threshold
+
+        # Save training metrics
+        metrics_path = os.path.join(self.output_dir, 'training_metrics.json')
+        with open(metrics_path, 'w') as f:
+            json.dump(metrics_history, f, indent=2)
+
+    def _filter_top_k_candidates(self, terms: List[str], top_k: int) -> List[Tuple[str, str]]:
+        """
+        OPTIONAL OPTIMIZATION (NOT in original Alexbek paper):
+        Filter candidate pairs to only include top-k most similar terms based on cosine similarity.
+        Memory-efficient chunked implementation.

         Args:
-
+            terms: List of unique terms
+            top_k: Number of most similar candidates to keep per term

         Returns:
-
-            - `pairs` is a list of (parent, child) strings,
-            - `terms` is a sorted list of unique term strings (parents ∪ children).
+            List of (parent, child) candidate pairs
         """
+        n_terms = len(terms)
+
+        # Stack all embeddings and normalize for cosine similarity
+        all_embeddings = torch.stack([self.term_to_vector[t] for t in terms], dim=0)
+        normalized_embeddings = F.normalize(all_embeddings, p=2, dim=1)
+
+        candidate_pairs = []
+
+        # Process in chunks to avoid OOM for large taxonomies
+        chunk_size = min(1000, n_terms)
+
+        print("Finding top-k similar terms for each term...")
+        for child_start in tqdm(range(0, n_terms, chunk_size), desc="Filtering candidates"):
+            child_end = min(child_start + chunk_size, n_terms)
+            child_chunk = normalized_embeddings[child_start:child_end]
+
+            # Compute similarities for this chunk
+            similarities = torch.mm(child_chunk, normalized_embeddings.t())
+
+            # Get top-k+1 for each child in chunk (to exclude self if needed)
+            top_k_values, top_k_indices = torch.topk(similarities, min(top_k + 1, n_terms), dim=1)
+
+            # Add pairs (excluding self-loops)
+            for local_idx, child_idx in enumerate(range(child_start, child_end)):
+                for parent_idx in top_k_indices[local_idx].cpu().tolist():
+                    if parent_idx != child_idx:
+                        candidate_pairs.append((terms[parent_idx], terms[child_idx]))
+
+        return candidate_pairs
+
+    def _extract_parent_child_pairs_and_terms(self, data: Any, test: bool) -> Tuple[List[Tuple[str, str]], List[str]]:
+        """Extract (parent, child) edges and unique terms from ontology data."""
         parent_child_pairs: List[Tuple[str, str]] = []
         unique_terms = set()
+
         for edge in getattr(data, "type_taxonomies").taxonomies:
             parent, child = str(edge.parent), str(edge.child)
-
+            if not test:
+                parent_child_pairs.append((parent, child))
             unique_terms.add(parent)
             unique_terms.add(child)
+
+        if test:
+            # In test mode, return empty pairs - will score all pairs in _taxonomy_discovery
+            pass
+
         return parent_child_pairs, sorted(unique_terms)

     def _sample_negative_pairs(
-
-
-
-
-
+        self,
+        positive_pairs: List[Tuple[str, str]],
+        terms: List[str],
+        ratio: float = 1.0,
+        seed: int = 42,
     ) -> List[Tuple[str, str]]:
-        """
-
-        Sampling is uniform over the Cartesian product of `terms` excluding
-        (x, x) self-pairs and any pair found in `positive_pairs`.
-
-        Args:
-            positive_pairs: Known positive edges to exclude.
-            terms: Candidate vocabulary (parents ∪ children).
-            ratio: Number of negatives per positive to draw.
-            seed: RNG seed used for reproducible sampling.
+        """
+        Sample negative pairs.

-
-
-            `int(len(positive_pairs) * ratio)`.
+        Original approach: All random negatives (hard_negative_ratio=0.0)
+        Optional: Mix of hard negatives and random negatives (hard_negative_ratio>0.0)
         """
         random.seed(seed)
         term_list = list(terms)
         positive_set = set(positive_pairs)
-
-
-
-
-
-
-
-
-
-
-
+
+        target_count = int(len(positive_pairs) * ratio)
+        hard_count = int(target_count * self.hard_negative_ratio) if self.hard_negative_ratio > 0 else 0
+        random_count = target_count - hard_count
+
+        negatives = []
+
+        # Hard negatives: pairs with high embedding similarity but not in taxonomy
+        if hard_count > 0:
+            print(f"Sampling {hard_count} hard negatives based on embedding similarity...")
+            all_embeddings = torch.stack([self.term_to_vector[t] for t in term_list])
+            normalized_embeddings = F.normalize(all_embeddings, p=2, dim=1)
+            similarity_matrix = torch.mm(normalized_embeddings, normalized_embeddings.t())
+
+            # For each term, get candidates sorted by similarity
+            for i in range(len(term_list)):
+                if len(negatives) >= hard_count:
+                    break
+
+                similarities = similarity_matrix[i]
+                sorted_indices = torch.argsort(similarities, descending=True)
+
+                for j in sorted_indices:
+                    j_idx = j.item()
+                    if i == j_idx:
+                        continue
+                    candidate = (term_list[j_idx], term_list[i])
+                    if candidate not in positive_set and candidate not in negatives:
+                        negatives.append(candidate)
+                        if len(negatives) >= hard_count:
+                            break
+        # Random negatives (ORIGINAL ALEXBEK APPROACH)
+        if random_count > 0:
+            print(f"Sampling {random_count} random negatives...")
+            max_attempts = random_count * 10
+            attempts = 0
+
+            while len(negatives) < target_count and attempts < max_attempts:
+                parent = random.choice(term_list)
+                child = random.choice(term_list)
+                attempts += 1
+
+                if parent == child:
+                    continue
+                candidate = (parent, child)
+                if candidate not in positive_set and candidate not in negatives:
+                    negatives.append(candidate)
+        if hard_count > 0:
+            print(f"Sampled {len(negatives)} negative pairs ({hard_count} hard, {len(negatives) - hard_count} random)")
+        else:
+            print(f"Sampled {len(negatives)} random negative pairs (original approach)")
         return negatives
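As a quick sanity check of the rewritten scorer above, the sketch below exercises the two code paths that the new CrossAttentionHead gains in this diff: element-wise scoring of 2-D (child, parent) batches and the broadcast 3-D path used by _score_all_pairs_efficient. The import path is an assumption; the shapes follow the docstring shown in the diff, and eval() is used so the new dropout layer is inactive.

    import torch
    from ontolearner.learner.taxonomy_discovery.alexbek import CrossAttentionHead  # assumed path

    head = CrossAttentionHead(hidden_size=384, num_heads=8, dropout=0.1).eval()

    # 2-D path: scores the i-th child against the i-th parent, one probability per pair.
    children = torch.randn(4, 384)
    parents = torch.randn(4, 384)
    print(head(children, parents).shape)        # torch.Size([4])

    # 3-D path: scores every child against every parent in one call,
    # which is how _score_all_pairs_efficient fills its NxN matrix chunk by chunk.
    all_children = torch.randn(1, 10, 384)
    all_parents = torch.randn(1, 25, 384)
    print(head(all_children, all_parents).shape)  # torch.Size([10, 25])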