sawnergy 1.0.4__tar.gz → 1.0.6__tar.gz

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.

Potentially problematic release.


This version of sawnergy might be problematic. Click here for more details.

Files changed (31)
  1. {sawnergy-1.0.4/sawnergy.egg-info → sawnergy-1.0.6}/PKG-INFO +1 -1
  2. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/__init__.py +3 -1
  3. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/embedding/SGNS_pml.py +65 -28
  4. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/embedding/embedder.py +9 -3
  5. {sawnergy-1.0.4 → sawnergy-1.0.6/sawnergy.egg-info}/PKG-INFO +1 -1
  6. {sawnergy-1.0.4 → sawnergy-1.0.6}/LICENSE +0 -0
  7. {sawnergy-1.0.4 → sawnergy-1.0.6}/NOTICE +0 -0
  8. {sawnergy-1.0.4 → sawnergy-1.0.6}/README.md +0 -0
  9. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/embedding/SGNS_torch.py +0 -0
  10. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/embedding/__init__.py +0 -0
  11. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/logging_util.py +0 -0
  12. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/rin/__init__.py +0 -0
  13. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/rin/rin_builder.py +0 -0
  14. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/rin/rin_util.py +0 -0
  15. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/sawnergy_util.py +0 -0
  16. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/visual/__init__.py +0 -0
  17. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/visual/visualizer.py +0 -0
  18. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/visual/visualizer_util.py +0 -0
  19. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/walks/__init__.py +0 -0
  20. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/walks/walker.py +0 -0
  21. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy/walks/walker_util.py +0 -0
  22. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy.egg-info/SOURCES.txt +0 -0
  23. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy.egg-info/dependency_links.txt +0 -0
  24. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy.egg-info/requires.txt +0 -0
  25. {sawnergy-1.0.4 → sawnergy-1.0.6}/sawnergy.egg-info/top_level.txt +0 -0
  26. {sawnergy-1.0.4 → sawnergy-1.0.6}/setup.cfg +0 -0
  27. {sawnergy-1.0.4 → sawnergy-1.0.6}/tests/test_embedding.py +0 -0
  28. {sawnergy-1.0.4 → sawnergy-1.0.6}/tests/test_rin.py +0 -0
  29. {sawnergy-1.0.4 → sawnergy-1.0.6}/tests/test_storage.py +0 -0
  30. {sawnergy-1.0.4 → sawnergy-1.0.6}/tests/test_visual.py +0 -0
  31. {sawnergy-1.0.4 → sawnergy-1.0.6}/tests/test_walks.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sawnergy
3
- Version: 1.0.4
3
+ Version: 1.0.6
4
4
  Summary: Toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations
5
5
  Home-page: https://github.com/Yehor-Mishchyriak/SAWNERGY
6
6
  Author: Yehor Mishchyriak
@@ -3,11 +3,13 @@ from . import logging_util
3
3
  from . import rin
4
4
  from . import visual
5
5
  from . import walks
6
+ from . import embedding
6
7
 
7
8
  __all__ = [
8
9
  "sawnergy_util",
9
10
  "logging_util",
10
11
  "rin",
11
12
  "visual",
12
- "walks"
13
+ "walks",
14
+ "embedding"
13
15
  ]
@@ -34,45 +34,76 @@ class SGNS_PureML(NN):
34
34
  seed: int | None = None,
35
35
  optim: Type[Optim],
36
36
  optim_kwargs: dict,
37
- lr_sched: Type[LRScheduler],
38
- lr_sched_kwargs: dict):
37
+ lr_sched: Type[LRScheduler] | None = None,
38
+ lr_sched_kwargs: dict | None = None,
39
+ device: str | None = None):
39
40
  """
40
41
  Args:
41
42
  V: Vocabulary size (number of nodes).
42
43
  D: Embedding dimensionality.
43
44
  seed: Optional RNG seed for negative sampling.
44
- optim: PureML optimizer class.
45
- optim_kwargs: Keyword arguments forwarded to the optimizer.
46
- lr_sched: PureML learning-rate scheduler class.
47
- lr_sched_kwargs: Keyword arguments forwarded to the scheduler.
45
+ optim: Optimizer class to instantiate.
46
+ optim_kwargs: Keyword arguments for the optimizer (required).
47
+ lr_sched: Optional learning-rate scheduler class.
48
+ lr_sched_kwargs: Keyword arguments for the scheduler (required if lr_sched is provided).
49
+ device: Target device string (e.g. "cuda"); accepted for API parity, ignored by PureML.
48
50
  """
51
+
52
+ if optim_kwargs is None:
53
+ raise ValueError("optim_kwargs must be provided")
54
+ if lr_sched is not None and lr_sched_kwargs is None:
55
+ raise ValueError("lr_sched_kwargs required when lr_sched is provided")
56
+
49
57
  self.V, self.D = int(V), int(D)
50
- self.in_emb = Embedding(V, D)
51
- self.out_emb = Embedding(V, D)
52
58
 
59
+ # embeddings
60
+ self.in_emb = Embedding(self.V, self.D)
61
+ self.out_emb = Embedding(self.V, self.D)
62
+
63
+ # seed + RNG for negative sampling
53
64
  self.seed = None if seed is None else int(seed)
54
65
  self._rng = np.random.default_rng(self.seed)
66
+ if self.seed is not None:
67
+ # optional: also set global NumPy seed for any non-RNG paths
68
+ np.random.seed(self.seed)
69
+
70
+ # API compatibility: PureML is CPU-only
71
+ self.device = "cpu"
72
+
73
+ # optimizer / scheduler
74
+ self.optim: Optim = optim(self.parameters, **optim_kwargs)
75
+ self.lr_sched: LRScheduler | None = (
76
+ lr_sched(optim=self.optim, **lr_sched_kwargs) if lr_sched is not None else None
77
+ )
55
78
 
56
- self.optim: Optim = optim(self.parameters, **optim_kwargs)
57
- self.lr_sched: LRScheduler = lr_sched(**lr_sched_kwargs)
58
- _logger.info("SGNS_PureML init: V=%d D=%d seed=%s", self.V, self.D, self.seed)
79
+ _logger.info(
80
+ "SGNS_PureML init: V=%d D=%d device=%s seed=%s",
81
+ self.V, self.D, self.device, self.seed
82
+ )
59
83
 
60
- def _sample_neg(self, B: int, K: int, dist: np.ndarray):
84
+ def _sample_neg(self, B: int, K: int, dist: np.ndarray) -> np.ndarray:
61
85
  """Draw negative samples according to the provided unigram distribution."""
62
86
  if dist.ndim != 1 or dist.size != self.V:
63
87
  raise ValueError(f"noise_dist must be 1-D with length {self.V}; got {dist.shape}")
64
88
  return self._rng.choice(self.V, size=(B, K), replace=True, p=dist)
65
89
 
66
- def predict(self, center: Tensor, pos: Tensor, neg: Tensor) -> Tensor:
67
- """Compute positive/negative logits for SGNS."""
68
- c = self.in_emb(center)
69
- pos_e = self.out_emb(pos)
70
- neg_e = self.out_emb(neg)
71
- pos_logits = t_sum(c * pos_e, axis=-1)
72
- neg_logits = t_sum(c[:, None, :] * neg_e, axis=-1)
73
- # ^^^
74
- # (B,1,D) * (B,K,D) → (B,K,D) → sum D → (B,K)
90
+ def predict(self, center: Tensor, pos: Tensor, neg: Tensor) -> tuple[Tensor, Tensor]:
91
+ """Compute positive/negative logits for SGNS.
92
+
93
+ Shapes:
94
+ center: (B,)
95
+ pos: (B,)
96
+ neg: (B, K)
97
+ Returns:
98
+ pos_logits: (B,)
99
+ neg_logits: (B, K)
100
+ """
101
+ c = self.in_emb(center) # (B, D)
102
+ pos_e = self.out_emb(pos) # (B, D)
103
+ neg_e = self.out_emb(neg) # (B, K, D)
75
104
 
105
+ pos_logits = t_sum(c * pos_e, axis=-1) # (B,)
106
+ neg_logits = t_sum(c[:, None, :] * neg_e, axis=-1) # (B, K)
76
107
  return pos_logits, neg_logits
77
108
 
78
109
  def fit(self,
@@ -94,29 +125,35 @@ class SGNS_PureML(NN):
94
125
  for epoch in range(1, num_epochs + 1):
95
126
  epoch_loss = 0.0
96
127
  batches = 0
128
+
97
129
  for cen, pos in DataLoader(data, batch_size=batch_size, shuffle=shuffle_data):
98
- neg = self._sample_neg(batch_size, num_negative_samples, noise_dist)
130
+ B = cen.data.shape[0] if isinstance(cen, Tensor) else len(cen)
99
131
 
132
+ neg_idx_np = self._sample_neg(B, num_negative_samples, noise_dist)
133
+ neg = Tensor(neg_idx_np, requires_grad=False)
100
134
  x_pos_logits, x_neg_logits = self(cen, pos, neg)
101
135
 
102
136
  y_pos = Tensor(np.ones_like(x_pos_logits.data))
103
137
  y_neg = Tensor(np.zeros_like(x_neg_logits.data))
104
138
 
105
- loss = BCE(y_pos, x_pos_logits, from_logits=True) + BCE(y_neg, x_neg_logits, from_logits=True)
139
+ loss = (
140
+ BCE(y_pos, x_pos_logits, from_logits=True)
141
+ + BCE(y_neg, x_neg_logits, from_logits=True)
142
+ )
106
143
 
107
144
  self.optim.zero_grad()
108
145
  loss.backward()
109
146
  self.optim.step()
110
-
111
- if lr_step_per_batch:
147
+
148
+ if lr_step_per_batch and self.lr_sched is not None:
112
149
  self.lr_sched.step()
113
150
 
114
- loss_value = float(np.asarray(loss.data).mean())
151
+ loss_value = float(np.asarray(loss.data))
115
152
  epoch_loss += loss_value
116
153
  batches += 1
117
154
  _logger.debug("Epoch %d batch %d loss=%.6f", epoch, batches, loss_value)
118
155
 
119
- if not lr_step_per_batch:
156
+ if (not lr_step_per_batch) and (self.lr_sched is not None):
120
157
  self.lr_sched.step()
121
158
 
122
159
  mean_loss = epoch_loss / max(batches, 1)
@@ -124,7 +161,7 @@ class SGNS_PureML(NN):
124
161
 
125
162
  @property
126
163
  def embeddings(self) -> np.ndarray:
127
- """Return the input embedding matrix as a NumPy array."""
164
+ """Return the input embedding matrix as a NumPy array (V, D)."""
128
165
  W: Tensor = self.in_emb.parameters[0]
129
166
  return np.asarray(W.data)
130
167
 
@@ -330,6 +330,7 @@ class Embedder:
330
330
  num_epochs: int,
331
331
  batch_size: int,
332
332
  *,
333
+ lr_step_per_batch: bool = False,
333
334
  shuffle_data: bool = True,
334
335
  dimensionality: int = 128,
335
336
  alpha: float = 0.75,
@@ -355,7 +356,8 @@ class Embedder:
355
356
  device: Optional device string for the Torch backend (e.g., ``"cuda"``).
356
357
  sgns_kwargs: Extra keyword arguments forwarded to the backend SGNS
357
358
  constructor. For PureML, required keys are:
358
- ``{"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}``.
359
+ ``{"optim", "optim_kwargs"}``; ``lr_sched`` is optional, but if
360
+ provided then ``lr_sched_kwargs`` must also be provided.
359
361
  _seed: Optional child seed for this frame's model initialization.
360
362
 
361
363
  Returns:
@@ -391,10 +393,14 @@ class Embedder:
391
393
 
392
394
  model_kwargs: dict[str, object] = dict(sgns_kwargs or {})
393
395
  if self.model_base == "pureml":
394
- required = {"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}
396
+ required = {"optim", "optim_kwargs"}
395
397
  missing = required.difference(model_kwargs)
396
398
  if missing:
397
399
  raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")
400
+ has_sched = ("lr_sched" in model_kwargs and model_kwargs["lr_sched"] is not None)
401
+ has_sched_kwargs = ("lr_sched_kwargs" in model_kwargs and model_kwargs["lr_sched_kwargs"] is not None)
402
+ if has_sched and not has_sched_kwargs:
403
+ raise ValueError("When providing lr_sched for PureML, you must also provide lr_sched_kwargs.")
398
404
 
399
405
  child_seed = int(self._seed if _seed is None else _seed)
400
406
  model_kwargs.update({
@@ -429,7 +435,7 @@ class Embedder:
429
435
  num_negative_samples,
430
436
  noise_probs,
431
437
  shuffle_data,
432
- lr_step_per_batch=False
438
+ lr_step_per_batch
433
439
  )
434
440
 
435
441
  embeddings = getattr(self.model, "embeddings", None)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sawnergy
3
- Version: 1.0.4
3
+ Version: 1.0.6
4
4
  Summary: Toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations
5
5
  Home-page: https://github.com/Yehor-Mishchyriak/SAWNERGY
6
6
  Author: Yehor Mishchyriak
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes