sawnergy 1.0.3__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sawnergy/__init__.py +3 -1
- sawnergy/embedding/SGNS_pml.py +324 -51
- sawnergy/embedding/SGNS_torch.py +282 -39
- sawnergy/embedding/__init__.py +26 -1
- sawnergy/embedding/embedder.py +426 -203
- sawnergy/embedding/visualizer.py +251 -0
- sawnergy/logging_util.py +1 -1
- sawnergy/rin/rin_builder.py +4 -4
- sawnergy/visual/visualizer.py +6 -6
- sawnergy/visual/visualizer_util.py +3 -0
- sawnergy/walks/walker.py +43 -22
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/METADATA +91 -57
- sawnergy-1.0.9.dist-info/RECORD +23 -0
- sawnergy-1.0.3.dist-info/RECORD +0 -22
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/WHEEL +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/licenses/LICENSE +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/licenses/NOTICE +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/top_level.txt +0 -0
sawnergy/embedding/SGNS_torch.py
CHANGED
@@ -27,46 +27,77 @@ class SGNS_Torch:
     def __init__(self,
                  V: int,
                  D: int,
+                 in_weights: torch.Tensor | np.ndarray | None = None,
+                 out_weights: torch.Tensor | np.ndarray | None = None,
                  *,
-
-
-
-
-
-
-        """
+                 seed: int | None = None,
+                 optim: Type[Optimizer] = torch.optim.SGD,
+                 optim_kwargs: dict | None = None,
+                 lr_sched: Type[LRScheduler] | None = None,
+                 lr_sched_kwargs: dict | None = None,
+                 device: str | None = None):
+        """Initialize SGNS (negative sampling) in PyTorch.
+
+        Shapes:
+          - Embedding tables:
+              in_weights: (V, D) or None — row i is the “input” vector for token i.
+              out_weights: (V, D) or None — row i is the “output” vector for token i.
+
         Args:
-            V: Vocabulary size (number of nodes).
+            V: Vocabulary size (number of nodes/tokens).
             D: Embedding dimensionality.
-
-
-
+            in_weights: Optional starting input-embedding matrix of shape (V, D).
+            out_weights: Optional starting output-embedding matrix of shape (V, D).
+            seed: Optional RNG seed for PyTorch (controls init, sampling, and shuffles).
+            optim: Optimizer class to instantiate. Defaults to plain SGD.
+            optim_kwargs: Keyword arguments for the optimizer. Defaults to {"lr": 0.1}.
             lr_sched: Optional learning-rate scheduler class.
-            lr_sched_kwargs: Keyword arguments for the scheduler.
-            device: Target device string (e.g.
+            lr_sched_kwargs: Keyword arguments for the scheduler (required if lr_sched is provided).
+            device: Target device string (e.g. "cuda"). Defaults to CUDA if available, else CPU.
         """
-
-            raise ValueError("optim_kwargs must be provided")
+        optim_kwargs = optim_kwargs or {"lr": 0.1}
         if lr_sched is not None and lr_sched_kwargs is None:
             raise ValueError("lr_sched_kwargs required when lr_sched is provided")
+
         self.V, self.D = int(V), int(D)
-
+
+        resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.device = torch.device(resolved_device)
-        _logger.info("SGNS_Torch init: V=%d D=%d device=%s seed=%s", self.V, self.D, self.device, seed)

-
-
-
+        # Seed torch
+        self.seed = None if seed is None else int(seed)
+        if self.seed is not None:
+            torch.manual_seed(self.seed)
             if self.device.type == "cuda":
-                torch.cuda.manual_seed_all(
+                torch.cuda.manual_seed_all(self.seed)

         # two embeddings as in/out matrices
-        self.in_emb = nn.Embedding(self.V, self.D)
-        self.out_emb = nn.Embedding(self.V, self.D)
+        self.in_emb = nn.Embedding(self.V, self.D, device=self.device)
+        self.out_emb = nn.Embedding(self.V, self.D, device=self.device)
+
+        # init / warm-start
+        with torch.no_grad():
+            if in_weights is not None:
+                w = torch.as_tensor(in_weights, dtype=torch.float32, device=self.device)
+                if w.shape != (self.V, self.D):
+                    raise ValueError(f"in_weights must be (V,D); got {tuple(w.shape)}")
+                self.in_emb.weight.copy_(w)
+            else:
+                nn.init.uniform_(self.in_emb.weight, -0.5 / self.D, 0.5 / self.D)
+
+            if out_weights is not None:
+                w = torch.as_tensor(out_weights, dtype=torch.float32, device=self.device)
+                if w.shape != (self.V, self.D):
+                    raise ValueError(f"out_weights must be (V,D); got {tuple(w.shape)}")
+                self.out_emb.weight.copy_(w)
+            else:
+                nn.init.zeros_(self.out_emb.weight)

         self.to(self.device)
+        _logger.info("SGNS_Torch init: V=%d D=%d device=%s seed=%s", self.V, self.D, self.device, self.seed)

         params = list(self.in_emb.parameters()) + list(self.out_emb.parameters())
+        # optimizer / scheduler
         self.opt = optim(params=params, **optim_kwargs)
         self.lr_sched = lr_sched(self.opt, **lr_sched_kwargs) if lr_sched is not None else None
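Note: as a quick usage sketch of the new warm-start and optimizer arguments (the sizes, the warm-start matrix, and the Adam settings below are illustrative values, not taken from the package):

import numpy as np
import torch
from sawnergy.embedding import SGNS_Torch

V, D = 1000, 64                                              # hypothetical vocabulary size / dimension
warm = (np.random.rand(V, D).astype(np.float32) - 0.5) / D   # optional (V, D) warm start

model = SGNS_Torch(
    V, D,
    in_weights=warm,                 # copied into the input embedding table
    seed=42,                         # seeds torch (and CUDA, when used)
    optim=torch.optim.Adam,          # any optimizer class; SGD with lr=0.1 is the default
    optim_kwargs={"lr": 1e-3},
    device="cpu",
)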
@@ -74,7 +105,17 @@ class SGNS_Torch:
                 center: torch.Tensor,
                 pos: torch.Tensor,
                 neg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compute positive/negative logits for SGNS.

+        Inputs:
+            center: int tensor of shape (B,), values in [0, V)
+            pos: int tensor of shape (B,), values in [0, V)
+            neg: int tensor of shape (B, K), values in [0, V)
+
+        Returns:
+            pos_logits: (B,)
+            neg_logits: (B, K)
+        """
         center = center.to(self.device, dtype=torch.long)
         pos = pos.to(self.device, dtype=torch.long)
         neg = neg.to(self.device, dtype=torch.long)
@@ -83,9 +124,8 @@ class SGNS_Torch:
         pe = self.out_emb(pos)  # (B, D)
         ne = self.out_emb(neg)  # (B, K, D)

-        pos_logits = (c * pe).sum(dim=-1)
-        neg_logits = (c.unsqueeze(1) * ne).sum(dim=-1)
-
+        pos_logits = (c * pe).sum(dim=-1)               # (B,)
+        neg_logits = (c.unsqueeze(1) * ne).sum(dim=-1)  # (B, K)
         return pos_logits, neg_logits

     __call__ = predict
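Note: to make the shape contract of predict concrete, a minimal sketch with a toy batch (reusing the model object from the sketch above; the indices are random and only meant to exercise the shapes):

import torch

B, K = 4, 5                                       # batch size, negatives per positive
center = torch.randint(0, model.V, (B,))          # (B,)
pos    = torch.randint(0, model.V, (B,))          # (B,)
neg    = torch.randint(0, model.V, (B, K))        # (B, K)

pos_logits, neg_logits = model(center, pos, neg)  # __call__ = predict
assert pos_logits.shape == (B,)
assert neg_logits.shape == (B, K)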
@@ -112,15 +152,27 @@ class SGNS_Torch:
         idx = np.arange(N)

         noise_probs = torch.as_tensor(noise_dist, dtype=torch.float32, device=self.device)
+        # normalize if slightly off; enforce nonnegativity + finite sum
+        if (noise_probs < 0).any():
+            raise ValueError("noise_dist has negative entries")
+        s = noise_probs.sum()
+        if not torch.isfinite(s) or float(s.item()) <= 0.0:
+            raise ValueError("noise_dist must have positive finite sum")
+        if abs(float(s.item()) - 1.0) > 1e-6:
+            noise_probs = noise_probs / s

         for epoch in range(1, int(num_epochs) + 1):
             epoch_loss = 0.0
             batches = 0
+
             if shuffle_data:
-
+                if self.seed is None:
+                    np.random.shuffle(idx)
+                else:
+                    np.random.default_rng(self.seed + epoch).shuffle(idx)

-            for
-                take = idx[
+            for s_ in range(0, N, int(batch_size)):
+                take = idx[s_:s_+int(batch_size)]
                 if take.size == 0:
                     continue
                 K = int(num_negative_samples)
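Note: fit only requires noise_dist to be nonnegative with a finite positive sum, and it renormalizes when the sum drifts from 1. A common way to build it (an assumption for illustration, not something this diff mandates) is the unigram distribution raised to the 3/4 power:

import numpy as np

def unigram_noise_dist(token_counts: np.ndarray, power: float = 0.75) -> np.ndarray:
    """Smoothed unigram noise distribution over the vocabulary."""
    probs = token_counts.astype(np.float64) ** power
    return probs / probs.sum()

# counts[i] = frequency of node/token i in the walk corpus (hypothetical numbers)
counts = np.array([10, 3, 7, 1, 25], dtype=np.float64)
noise_dist = unigram_noise_dist(counts)   # nonnegative, sums to 1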
@@ -128,19 +180,199 @@ class SGNS_Torch:

                 cen = torch.as_tensor(centers[take], dtype=torch.long, device=self.device)   # (B,)
                 pos = torch.as_tensor(contexts[take], dtype=torch.long, device=self.device)  # (B,)
-                neg = torch.multinomial(noise_probs, num_samples=B * K, replacement=True).view(B, K) # (B,K)
+                neg = torch.multinomial(noise_probs, num_samples=B * K, replacement=True).view(B, K)  # (B,K)

                 pos_logits, neg_logits = self(cen, pos, neg)

-                # BCE(+)
                 y_pos = torch.ones_like(pos_logits)
-                loss_pos = bce(pos_logits, y_pos)
-
-                # BCE(-):
                 y_neg = torch.zeros_like(neg_logits)
+                loss_pos = bce(pos_logits, y_pos)
                 loss_neg = bce(neg_logits, y_neg)

-                loss = loss_pos + loss_neg
+                loss = loss_pos + K * loss_neg
+
+                self.opt.zero_grad(set_to_none=True)
+                loss.backward()
+                self.opt.step()
+
+                if lr_step_per_batch and self.lr_sched is not None:
+                    self.lr_sched.step()
+
+                epoch_loss += float(loss.detach().cpu().item())
+                batches += 1
+                _logger.debug("Epoch %d batch %d loss=%.6f", epoch, batches, loss.item())
+
+            if not lr_step_per_batch and self.lr_sched is not None:
+                self.lr_sched.step()
+
+            mean_loss = epoch_loss / max(batches, 1)
+            _logger.info("Epoch %d/%d mean_loss=%.6f", epoch, num_epochs, mean_loss)
+
+    @property
+    def in_embeddings(self) -> np.ndarray:
+        W = self.in_emb.weight.detach().cpu().numpy()  # (V, D)
+        _logger.debug("In emb shape: %s", W.shape)
+        return W
+
+    @property
+    def out_embeddings(self) -> np.ndarray:
+        W = self.out_emb.weight.detach().cpu().numpy()  # (V, D)
+        _logger.debug("Out emb shape: %s", W.shape)
+        return W
+
+    @property
+    def avg_embeddings(self) -> np.ndarray:
+        return 0.5 * (self.in_embeddings + self.out_embeddings)
+
+    # tiny helper for device move
+    def to(self, device):
+        self.in_emb.to(device)
+        self.out_emb.to(device)
+        return self
+
+
+class SG_Torch:
+    """PyTorch implementation of Skip-Gram (full softmax, **no biases**).
+
+    This variant uses **no bias terms**: both projections are pure linear maps.
+
+    Computation:
+        x = one_hot(center, V)    # (B, V)
+        y = x @ W_in              # (B, D), with W_in ∈ R^{VxD}
+        logits = y @ W_out        # (B, V), with W_out ∈ R^{DxV}
+        loss = CrossEntropyLoss(logits, context)
+
+    Embeddings:
+        - Input embeddings = rows of W_in → shape (V, D)
+        - Output embeddings = rows of W_outᵀ → shape (V, D)
+    """
+
+    def __init__(self,
+                 V: int,
+                 D: int,
+                 in_weights: torch.Tensor | np.ndarray | None = None,
+                 out_weights: torch.Tensor | np.ndarray | None = None,
+                 *,
+                 seed: int | None = None,
+                 optim: Type[Optimizer] = torch.optim.SGD,
+                 optim_kwargs: dict | None = None,
+                 lr_sched: Type[LRScheduler] | None = None,
+                 lr_sched_kwargs: dict | None = None,
+                 device: str | None = None):
+        """Initialize the plain Skip-Gram (full softmax, **no biases**) model in PyTorch.
+
+        Shapes:
+          - Linear maps (no bias):
+              W_in: (V, D) — rows are input embeddings for tokens.
+              W_out: (D, V) — maps D→V; rows of W_outᵀ are output embeddings.
+
+          - Warm-starts:
+              in_weights: (V, D) or None — copied into W_in if provided.
+              out_weights: (D, V) or None — copied into W_out if provided.
+
+        Args:
+            V: Vocabulary size (number of nodes/tokens).
+            D: Embedding dimensionality.
+            in_weights: Optional starting matrix for W_in with shape (V, D).
+            out_weights: Optional starting matrix for W_out with shape (D, V).
+            seed: Optional RNG seed for reproducibility.
+            optim: Optimizer class to instantiate. Defaults to :class:`torch.optim.SGD`.
+            optim_kwargs: Keyword args for the optimizer. Defaults to ``{"lr": 0.1}``.
+            lr_sched: Optional learning-rate scheduler class.
+            lr_sched_kwargs: Keyword args for the scheduler (required if ``lr_sched`` is provided).
+            device: Target device string (e.g., ``"cuda"``). Defaults to CUDA if available, else CPU.
+
+        Notes:
+            The encoder/decoder are **bias-free** linear layers acting on one-hot centers:
+              • ``in_emb = nn.Linear(V, D, bias=False)``
+              • ``out_emb = nn.Linear(D, V, bias=False)``
+            Forward pass produces vocabulary-sized logits and is trained with CrossEntropyLoss.
+        """
+        optim_kwargs = optim_kwargs or {"lr": 0.1}
+        if lr_sched is not None and lr_sched_kwargs is None:
+            raise ValueError("lr_sched_kwargs required when lr_sched is provided")
+
+        self.V, self.D = int(V), int(D)
+
+        resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(resolved_device)
+
+        # Seed torch (no global NumPy seeding)
+        self.seed = None if seed is None else int(seed)
+        if self.seed is not None:
+            torch.manual_seed(self.seed)
+            if self.device.type == "cuda":
+                torch.cuda.manual_seed_all(self.seed)
+
+        self.in_emb = nn.Linear(self.V, self.D, bias=False, device=self.device)
+        self.out_emb = nn.Linear(self.D, self.V, bias=False, device=self.device)
+
+        # warm-starts (note Linear weights are (out_features, in_features))
+        with torch.no_grad():
+            if in_weights is not None:
+                w_in = torch.as_tensor(in_weights, dtype=torch.float32, device=self.device)
+                if w_in.shape != (self.V, self.D):
+                    raise ValueError(f"in_weights must be (V,D); got {tuple(w_in.shape)}")
+                self.in_emb.weight.copy_(w_in.T)  # (D,V)
+            # else: use default PyTorch init
+
+            if out_weights is not None:
+                w_out = torch.as_tensor(out_weights, dtype=torch.float32, device=self.device)
+                if w_out.shape != (self.D, self.V):
+                    raise ValueError(f"out_weights must be (D,V); got {tuple(w_out.shape)}")
+                self.out_emb.weight.copy_(w_out)  # (V,D) weight is (V,D) because (out=in V, in=D)
+            # else: default init
+
+        self.to(self.device)
+        _logger.info("SG_Torch init: V=%d D=%d device=%s seed=%s", self.V, self.D, self.device, self.seed)
+
+        params = list(self.in_emb.parameters()) + list(self.out_emb.parameters())
+        # optimizer / scheduler
+        self.opt = optim(params=params, **optim_kwargs)
+        self.lr_sched = lr_sched(self.opt, **lr_sched_kwargs) if lr_sched is not None else None
+
+    def predict(self, center: torch.Tensor) -> torch.Tensor:
+        center = center.to(self.device, dtype=torch.long)
+        c = nn.functional.one_hot(center, num_classes=self.V).to(dtype=torch.float32, device=self.device)
+        y = self.in_emb(c)
+        z = self.out_emb(y)
+        return z
+
+    __call__ = predict
+
+    def fit(self,
+            centers: np.ndarray,
+            contexts: np.ndarray,
+            num_epochs: int,
+            batch_size: int,
+            shuffle_data: bool,
+            lr_step_per_batch: bool,
+            **_ignore):
+        cce = nn.CrossEntropyLoss(reduction="mean")
+
+        N = centers.shape[0]
+        idx = np.arange(N)
+
+        for epoch in range(1, int(num_epochs) + 1):
+            epoch_loss = 0.0
+            batches = 0
+
+            if shuffle_data:
+                if self.seed is None:
+                    np.random.shuffle(idx)
+                else:
+                    np.random.default_rng(self.seed + epoch).shuffle(idx)
+
+            for s in range(0, N, int(batch_size)):
+                take = idx[s:s+int(batch_size)]
+                if take.size == 0:
+                    continue
+
+                cen = torch.as_tensor(centers[take], dtype=torch.long, device=self.device)
+                ctx = torch.as_tensor(contexts[take], dtype=torch.long, device=self.device)
+
+                logits = self(cen)
+                loss = cce(logits, ctx)

                 self.opt.zero_grad(set_to_none=True)
                 loss.backward()
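Note: the change from loss_pos + loss_neg to loss_pos + K * loss_neg matches the usual SGNS objective. With mean-reduced BCE-with-logits, loss_neg averages over all B * K negative logits, so scaling by K restores "one positive plus the sum of K negatives" per training pair. A self-contained sketch of the same computation (assuming mean reduction for bce, which the K-scaling implies but the diff does not show):

import torch
import torch.nn.functional as F

def sgns_batch_loss(pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
    """pos_logits: (B,), neg_logits: (B, K) -> scalar SGNS loss for the batch."""
    K = neg_logits.shape[1]
    loss_pos = F.binary_cross_entropy_with_logits(pos_logits, torch.ones_like(pos_logits))
    loss_neg = F.binary_cross_entropy_with_logits(neg_logits, torch.zeros_like(neg_logits))
    # mean over the B*K negative entries, times K, equals the per-pair sum of K negative terms averaged over B
    return loss_pos + K * loss_neg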
@@ -160,9 +392,20 @@ class SGNS_Torch:
             _logger.info("Epoch %d/%d mean_loss=%.6f", epoch, num_epochs, mean_loss)

     @property
-    def
-
-
+    def in_embeddings(self) -> np.ndarray:
+        W = self.in_emb.weight.detach().T.cpu().numpy()  # (V, D)
+        _logger.debug("In emb shape: %s", W.shape)
+        return W
+
+    @property
+    def out_embeddings(self) -> np.ndarray:
+        W = self.out_emb.weight.detach().cpu().numpy()  # (V, D)
+        _logger.debug("Out emb shape: %s", W.shape)
+        return W
+
+    @property
+    def avg_embeddings(self) -> np.ndarray:
+        return 0.5 * (self.in_embeddings + self.out_embeddings)

     # tiny helper for device move
     def to(self, device):
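Note: nn.Linear(V, D, bias=False) stores its weight as (D, V), which is why in_embeddings transposes it back to (V, D); the decoder nn.Linear(D, V, bias=False) already stores (V, D), so out_embeddings returns it directly. A toy sketch of why a one-hot row through the bias-free encoder is just a row lookup (sizes are illustrative):

import torch
import torch.nn as nn

V, D = 6, 3
enc = nn.Linear(V, D, bias=False)         # weight shape: (D, V)

i = 4
one_hot = nn.functional.one_hot(torch.tensor([i]), num_classes=V).float()  # (1, V)
row_via_linear = enc(one_hot).squeeze(0)  # (D,) = one_hot @ weight.T
row_via_lookup = enc.weight.T[i]          # i-th input embedding, shape (D,)

assert torch.allclose(row_via_linear, row_via_lookup)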
@@ -171,7 +414,7 @@ class SGNS_Torch:
         return self


-__all__ = ["SGNS_Torch"]
+__all__ = ["SGNS_Torch", "SG_Torch"]

 if __name__ == "__main__":
     pass
sawnergy/embedding/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations

-from .embedder import Embedder
+from .embedder import Embedder, align_frames
+from .visualizer import Visualizer

 def __getattr__(name: str):
     """Lazily expose optional backends."""
@@ -14,6 +15,16 @@ def __getattr__(name: str):
             ) from exc
         return SGNS_Torch

+    if name == "SG_Torch":
+        try:
+            from .SGNS_torch import SG_Torch
+        except Exception as exc:
+            raise ImportError(
+                "PyTorch backend requested but torch is not installed. "
+                "Install PyTorch via `pip install torch` (see https://pytorch.org/get-started)."
+            ) from exc
+        return SG_Torch
+
     if name == "SGNS_PureML":
         try:
             from .SGNS_pml import SGNS_PureML
@@ -24,11 +35,25 @@ def __getattr__(name: str):
                 "Install PureML first via `pip install ym-pure-ml` "
             ) from exc

+    if name == "SG_PureML":
+        try:
+            from .SGNS_pml import SG_PureML
+            return SG_PureML
+        except Exception as exc:
+            raise ImportError(
+                "PureML is not installed. "
+                "Install PureML first via `pip install ym-pure-ml` "
+            ) from exc
+
     raise AttributeError(name)


 __all__ = [
     "Embedder",
+    "align_frames",
+    "Visualizer",
     "SGNS_PureML",
     "SGNS_Torch",
+    "SG_PureML",
+    "SG_Torch"
 ]
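Note: with these lazy attributes, the optional backends resolve on first import and fail with an actionable ImportError when their dependency is missing. For example (a sketch; the except branch only triggers when torch or ym-pure-ml is absent):

# Resolved lazily through sawnergy.embedding.__getattr__
from sawnergy.embedding import SG_Torch          # requires torch

try:
    from sawnergy.embedding import SG_PureML     # requires ym-pure-ml
except ImportError as exc:
    print(f"PureML backend unavailable: {exc}")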