sawnergy 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy might be problematic. Click here for more details.
- sawnergy/__init__.py +13 -0
- sawnergy/embedding/SGNS_pml.py +135 -0
- sawnergy/embedding/SGNS_torch.py +177 -0
- sawnergy/embedding/__init__.py +34 -0
- sawnergy/embedding/embedder.py +578 -0
- sawnergy/logging_util.py +54 -0
- sawnergy/rin/__init__.py +9 -0
- sawnergy/rin/rin_builder.py +936 -0
- sawnergy/rin/rin_util.py +391 -0
- sawnergy/sawnergy_util.py +1182 -0
- sawnergy/visual/__init__.py +42 -0
- sawnergy/visual/visualizer.py +690 -0
- sawnergy/visual/visualizer_util.py +387 -0
- sawnergy/walks/__init__.py +16 -0
- sawnergy/walks/walker.py +795 -0
- sawnergy/walks/walker_util.py +384 -0
- sawnergy-1.0.0.dist-info/METADATA +290 -0
- sawnergy-1.0.0.dist-info/RECORD +22 -0
- sawnergy-1.0.0.dist-info/WHEEL +5 -0
- sawnergy-1.0.0.dist-info/licenses/LICENSE +201 -0
- sawnergy-1.0.0.dist-info/licenses/NOTICE +4 -0
- sawnergy-1.0.0.dist-info/top_level.txt +1 -0
sawnergy/walks/walker.py
ADDED
|
@@ -0,0 +1,795 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# third-pary
|
|
4
|
+
import numpy as np
|
|
5
|
+
# built-in
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Literal
|
|
8
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
# local
|
|
12
|
+
from . import walker_util
|
|
13
|
+
from .. import sawnergy_util
|
|
14
|
+
|
|
15
|
+
# *----------------------------------------------------*
|
|
16
|
+
# GLOBALS
|
|
17
|
+
# *----------------------------------------------------*
|
|
18
|
+
|
|
19
|
+
_logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
# *----------------------------------------------------*
|
|
22
|
+
# CLASSES
|
|
23
|
+
# *----------------------------------------------------*
|
|
24
|
+
|
|
25
|
+
class Walker:
|
|
26
|
+
"""Random-walk sampler over time-indexed **transition** matrices.
|
|
27
|
+
|
|
28
|
+
Loads per-timestamp stacks of attractive/repulsive **transition** matrices
|
|
29
|
+
(shape ``(T, N, N)``) previously written by the RIN builder, exposes
|
|
30
|
+
sampling for random walks (RW) and self-avoiding walks (SAW), and can
|
|
31
|
+
optionally advance a *time* coordinate using cosine-similarity between
|
|
32
|
+
transition slices (time-aware walks).
|
|
33
|
+
|
|
34
|
+
Matrices live in OS shared memory via :class:`walker_util.SharedNDArray`,
|
|
35
|
+
so multiple processes can read the same buffers zero-copy. Each `Walker`
|
|
36
|
+
instance owns a dedicated :class:`numpy.random.Generator` seeded from a
|
|
37
|
+
master seed for reproducibility.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __init__(self,
             RIN_path: str | Path,
             *,
             seed: int | None = None) -> None:
    """Load transition stacks into shared memory and set up the RNG.

    Data source:
        Dataset names are read from the archive attributes
        ``'attractive_transitions_name'`` and ``'repulsive_transitions_name'``.
        A channel whose attribute is ``None`` is treated as unavailable.

    Args:
        RIN_path: Path handed to ``sawnergy_util.ArrayStorage`` (opened
            with ``mode="r"``).
        seed: Optional master seed for this instance's RNG. If ``None``,
            a 32-bit seed is derived from fresh OS entropy via
            ``np.random.SeedSequence`` (independent of NumPy's legacy
            global RNG and of the platform's default integer width).

    Raises:
        ValueError: If no matrices are found or arrays are not rank-3.
        RuntimeError: If attractive/repulsive shapes differ or matrices
            are not square along the last two axes.
    """
    _logger.info("Initializing Walker from %s", RIN_path)

    # Load numpy arrays from read-only storage
    with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
        attr_name = storage.get_attr("attractive_transitions_name")
        repuls_name = storage.get_attr("repulsive_transitions_name")
        attr_matrices: np.ndarray | None = (
            storage.read(attr_name, slice(None)) if attr_name is not None else None
        )
        repuls_matrices: np.ndarray | None = (
            storage.read(repuls_name, slice(None)) if repuls_name is not None else None
        )

    _logger.debug(
        "Loaded matrices | attr: shape=%s dtype=%s | repuls: shape=%s dtype=%s",
        getattr(attr_matrices, "shape", None), getattr(attr_matrices, "dtype", None),
        getattr(repuls_matrices, "shape", None), getattr(repuls_matrices, "dtype", None),
    )

    # Shape & consistency checks (expect (T, N, N))
    if (attr_matrices is not None) and (repuls_matrices is not None):
        if attr_matrices.ndim != 3 or repuls_matrices.ndim != 3:
            _logger.error(
                "Bad ranks: attr.ndim=%s repuls.ndim=%s; expected both 3",
                getattr(attr_matrices, "ndim", None),
                getattr(repuls_matrices, "ndim", None),
            )
            raise ValueError(
                f"Expected (T,N,N) arrays; got {getattr(attr_matrices, 'shape', None)} "
                f"and {getattr(repuls_matrices, 'shape', None)}"
            )
        if attr_matrices.shape != repuls_matrices.shape:
            _logger.error("Shape mismatch: attr=%s repuls=%s",
                          attr_matrices.shape, repuls_matrices.shape)
            raise RuntimeError(
                f"ATTR/REPULS shapes must match exactly; got {attr_matrices.shape} vs {repuls_matrices.shape}"
            )
        T, N1, N2 = attr_matrices.shape
    elif attr_matrices is not None:
        if attr_matrices.ndim != 3:
            raise ValueError(f"Expected (T,N,N); got {attr_matrices.shape}")
        T, N1, N2 = attr_matrices.shape
    elif repuls_matrices is not None:
        if repuls_matrices.ndim != 3:
            raise ValueError(f"Expected (T,N,N); got {repuls_matrices.shape}")
        T, N1, N2 = repuls_matrices.shape
    else:
        _logger.error("No transition matrices detected in %s", RIN_path)
        raise ValueError("No transition matrices detected.")

    if N1 != N2:
        _logger.error("Non-square matrices along last two dims: (%s, %s)", N1, N2)
        raise RuntimeError(
            f"Transition matrices must be square along last two dims; got ({N1}, {N2})"
        )

    _logger.info("Transition stack OK: T=%d, N=%d", T, N1)

    # SHARED MEMORY ELEMENTS (read-only default views; fancy indexing via .array)
    self.attr_matrices: walker_util.SharedNDArray | None = (
        walker_util.SharedNDArray.create(
            shape=attr_matrices.shape,
            dtype=attr_matrices.dtype,
            from_array=attr_matrices,
        ) if attr_matrices is not None else None
    )
    try:
        self.repuls_matrices: walker_util.SharedNDArray | None = (
            walker_util.SharedNDArray.create(
                shape=repuls_matrices.shape,
                dtype=repuls_matrices.dtype,
                from_array=repuls_matrices,
            ) if repuls_matrices is not None else None
        )
    except Exception:
        # The attractive segment already exists at this point; release it so a
        # failed construction does not leave an orphaned shared-memory segment.
        if self.attr_matrices is not None:
            try:
                self.attr_matrices.close()
                self.attr_matrices.unlink()
            except Exception as cleanup_err:
                _logger.debug("cleanup after failed init suppressed: %r", cleanup_err)
        raise

    _logger.debug(
        "SharedNDArray created | attr name=%r; repuls name=%r",
        getattr(self.attr_matrices, "name", None),
        getattr(self.repuls_matrices, "name", None),
    )

    # AUXILIARY NETWORK INFORMATION
    self.time_stamp_count = T
    self.node_count = N1

    # NETWORK ELEMENT
    self.nodes = np.arange(0, self.node_count, 1, np.intp)
    self.time_stamps = np.arange(0, self.time_stamp_count, 1, np.intp)
    _logger.debug(
        "Index arrays built: nodes=%d, time_stamps=%d",
        self.nodes.size, self.time_stamps.size
    )

    # INTERNAL
    self._memory_cleaned_up: bool = False
    if seed is None:
        # generate_state(1) yields one uint32 word, so the seed always fits in
        # 32 bits regardless of the platform's default integer dtype.
        self._seed = int(np.random.SeedSequence().generate_state(1)[0])
    else:
        self._seed = int(seed)
    self.rng = np.random.default_rng(self._seed)
    _logger.info("RNG initialized (master seed=%d)", self._seed)
|
|
159
|
+
|
|
160
|
+
# explicit resource cleanup
|
|
161
|
+
def close(self) -> None:
    """Release shared-memory handles; unlink segments when in the main process.

    Safe to call repeatedly: once the cleanup flag is set, later calls
    return immediately. Local handles are closed first. Unlinking is only
    attempted when ``sawnergy_util.is_main_process()`` is true, and a
    ``FileNotFoundError`` raised by ``unlink()`` is logged as a warning
    rather than propagated. The cleanup flag is set even if closing fails.
    """
    if self._memory_cleaned_up:
        _logger.debug("close(): already cleaned up; returning")
        return

    _logger.debug("Closing Walker resources (is_main=%s)", sawnergy_util.is_main_process())
    try:
        for shared in (self.attr_matrices, self.repuls_matrices):
            if shared is not None:
                shared.close()
        _logger.debug("SharedNDArray handles closed")

        if not sawnergy_util.is_main_process():
            _logger.debug("Not main process; skipping unlink")
        else:
            _logger.debug("Attempting to unlink shared memory segments (main process)")
            channels = (
                (self.attr_matrices, "attr SharedMemory already unlinked"),
                (self.repuls_matrices, "repuls SharedMemory already unlinked"),
            )
            for shared, already_gone_msg in channels:
                try:
                    if shared is not None:
                        shared.unlink()
                except FileNotFoundError:
                    _logger.warning(already_gone_msg)
    finally:
        # Mark as cleaned up no matter what happened above.
        self._memory_cleaned_up = True
        _logger.debug("Cleanup complete")
|
|
197
|
+
|
|
198
|
+
def __enter__(self):
    """Start a ``with`` block; the walker itself is the managed object."""
    _logger.debug("__enter__")
    return self
|
|
202
|
+
|
|
203
|
+
def __exit__(self, exc_type, exc, tb):
    """Leave the ``with`` block and release resources via :meth:`close`.

    Returns ``None``, so an exception raised inside the block propagates.
    """
    exc_label = getattr(exc_type, "__name__", exc_type)
    _logger.debug("__exit__(exc_type=%s)", exc_label)
    self.close()
|
|
207
|
+
|
|
208
|
+
def __del__(self):
    """Finalizer: attempt :meth:`close` once, never letting an exception escape."""
    try:
        # A missing flag (e.g. __init__ failed early) is treated as "nothing to clean".
        already_done = getattr(self, "_memory_cleaned_up", True)
        if already_done:
            return
        _logger.debug("__del__: best-effort close")
        self.close()
    except Exception as e:
        _logger.debug("__del__ suppressed exception: %r", e)
|
|
216
|
+
|
|
217
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
218
|
+
# PRIVATE
|
|
219
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
220
|
+
|
|
221
|
+
def _matrices_of_interaction_type(
    self, interaction_type: Literal["attr", "repuls"]
) -> walker_util.SharedNDArray:
    """Look up the shared transition stack for one interaction channel.

    Args:
        interaction_type: ``"attr"`` selects ``self.attr_matrices``;
            ``"repuls"`` selects ``self.repuls_matrices``.

    Returns:
        The stored ``SharedNDArray`` wrapper for that channel.

    Raises:
        ValueError: If the selected channel is ``None`` or the selector is
            neither ``"attr"`` nor ``"repuls"``.
    """
    _logger.debug("_matrices_of_interaction_type(%s)", interaction_type)

    # Reject unknown selectors first; valid ones fall through to the lookup.
    if interaction_type not in ("attr", "repuls"):
        _logger.error("interaction_type invalid: %r", interaction_type)
        raise ValueError("`interaction_type` must be 'attr' or 'repuls'.")

    if interaction_type == "attr":
        shared = self.attr_matrices
        missing_msg = "Attractive transition matrices are not present in the RIN archive."
    else:
        shared = self.repuls_matrices
        missing_msg = "Repulsive transition matrices are not present in the RIN archive."

    if shared is None:
        raise ValueError(missing_msg)
    return shared
|
|
250
|
+
|
|
251
|
+
def _extract_prob_vector(
    self,
    node: int,
    time_stamp: int,
    interaction_type: Literal["attr", "repuls"],
) -> np.ndarray:
    """Return a private copy of row ``node`` of the matrix at ``time_stamp``.

    The channel stack is indexed by ``time_stamp`` first and the selected
    matrix by ``node``; the row is copied so the result no longer refers
    to the shared buffer.
    """
    _logger.debug("_extract_prob_vector(node=%d, t=%d, type=%s)",
                  node, time_stamp, interaction_type)
    stack = self._matrices_of_interaction_type(interaction_type)
    row = stack[time_stamp][node, :].copy()  # copy => independent of the shared buffer
    _logger.debug("prob vector extracted: shape=%s dtype=%s", row.shape, row.dtype)
    return row
|
|
268
|
+
|
|
269
|
+
def _step_node(
    self,
    node: int,
    interaction_type: Literal["attr", "repuls"],
    time_stamp: int = 0,
    avoid: np.typing.ArrayLike | None = None,
) -> tuple[int, np.ndarray | None]:
    """Draw the next node from the transition row of ``node``.

    Without ``avoid`` the row is passed to ``rng.choice`` as-is. With
    ``avoid`` the listed nodes are removed from the candidates, the
    remaining weights are renormalized with ``walker_util.l1_norm``, and
    the chosen node is appended to the returned avoidance array.

    Returns:
        ``(next_node, updated_avoid)``; ``updated_avoid`` is ``None`` when
        no avoidance set was supplied.

    Raises:
        RuntimeError: If every node is excluded or the probability mass of
            the candidates is not positive.
    """
    avoid_len = None if avoid is None else np.asarray(avoid).size
    _logger.debug(
        "_step_node(node=%d, t=%d, type=%s, avoid_len=%s)",
        node, time_stamp, interaction_type,
        avoid_len,
    )
    row = self._extract_prob_vector(node, time_stamp, interaction_type)

    # Plain random-walk step: sample straight from the stored row.
    if avoid is None:
        mass = float(np.sum(row))
        _logger.debug("_step_node: no-avoid branch; mass=%.6f", mass)
        if mass <= 0.0:
            _logger.error("_step_node: zero probability mass without avoidance")
            raise RuntimeError("No valid node transitions: zero probability mass.")
        picked = self.rng.choice(self.nodes, p=row)
        return int(picked), None

    # Self-avoiding step: restrict to nodes not yet excluded.
    visited = np.asarray(avoid, dtype=np.intp)
    candidates = np.setdiff1d(self.nodes, visited, assume_unique=False)
    _logger.debug("_step_node: keep.size=%d (after removing %d avoids)",
                  candidates.size, visited.size)
    if candidates.size == 0:
        _logger.error("_step_node: empty candidate set after avoidance")
        raise RuntimeError("No available node transitions (avoiding all nodes).")

    weights = walker_util.l1_norm(row[candidates])
    mass = float(weights.sum())
    _logger.debug("_step_node: normalized mass=%.6f", mass)
    if mass <= 0.0:
        _logger.error("_step_node: zero probability mass after masking/normalization")
        raise RuntimeError("No valid node transitions: probability mass is zero.")

    next_node = int(self.rng.choice(candidates, p=weights))
    _logger.debug("_step_node: chosen next_node=%d", next_node)
    visited = np.append(visited, next_node).astype(np.intp, copy=False)
    return next_node, visited
|
|
317
|
+
|
|
318
|
+
def _step_time(
    self,
    time_stamp: int,
    interaction_type: Literal["attr", "repuls"],
    stickiness: float,
    on_no_options: Literal["raise", "loop"],
    avoid: np.typing.ArrayLike | None,
) -> tuple[int, np.ndarray | None]:
    """Draw the next time stamp.

    With probability ``stickiness`` the current time stamp is kept.
    Otherwise the current time stamp is added to the avoidance set and the
    next one is drawn from the remaining time stamps, weighted by
    ``walker_util.cosine_similarity`` against the current matrix and
    normalized with ``walker_util.l1_norm``.

    ``on_no_options`` is consulted only when no candidate remains:
    ``"raise"`` raises, ``"loop"`` resets the avoidance set to the current
    time stamp alone.

    Returns:
        ``(next_time_stamp, avoidance_array)``.

    Raises:
        ValueError: If ``stickiness`` is outside ``[0, 1]``, or if no
            candidate remains and ``on_no_options`` is not a known mode.
        RuntimeError: If no candidate remains (``"raise"`` mode, or
            ``"loop"`` mode with a single time stamp) or the probability
            mass is not positive.
    """
    _logger.debug(
        "_step_time(t=%d, type=%s, stickiness=%.3f, on_no_options=%s, avoid_len=%s)",
        time_stamp, interaction_type, stickiness, on_no_options,
        None if avoid is None else np.asarray(avoid).size,
    )
    if not (0.0 <= stickiness <= 1.0):
        _logger.error("stickiness out of range: %r", stickiness)
        raise ValueError("stickiness must be in [0,1]")

    if avoid is None:
        to_avoid = np.array([], dtype=np.intp)
    else:
        to_avoid = np.asarray(avoid, dtype=np.intp)

    # One uniform draw decides whether we stay at the current time stamp.
    draw = float(self.rng.random())
    _logger.debug("_step_time: rand=%.6f vs stickiness=%.6f", draw, float(stickiness))
    if draw < float(stickiness):
        _logger.debug("_step_time: sticking at t=%d", time_stamp)
        return int(time_stamp), to_avoid

    # Not sticking: the current time stamp is no longer a candidate.
    with_current = np.append(to_avoid, time_stamp).astype(np.intp, copy=False)
    to_avoid = np.unique(with_current)
    keep = np.setdiff1d(self.time_stamps, to_avoid, assume_unique=True)
    _logger.debug("_step_time: keep.size=%d (to_avoid.size=%d)", keep.size, to_avoid.size)

    matrices = self._matrices_of_interaction_type(interaction_type)
    current_matrix = matrices[time_stamp]  # basic indexing on axis 0

    if keep.size == 0:
        if on_no_options == "raise":
            _logger.error("_step_time: no available timestamps (avoid=%s)", np.unique(to_avoid))
            raise RuntimeError(f"No available time stamps (avoid={np.unique(to_avoid)})")
        if on_no_options != "loop":
            _logger.error("_step_time: invalid on_no_options=%r", on_no_options)
            raise ValueError("on_no_options must be 'raise' or 'loop'")
        # Loop mode: forget the history and allow everything except the current time.
        _logger.warning("_step_time: looping over all except current (t=%d)", time_stamp)
        to_avoid = np.array([time_stamp], dtype=np.intp)
        keep = self.time_stamps[self.time_stamps != time_stamp]
        if keep.size == 0:
            _logger.error("_step_time: loop mode impossible (T==1)")
            raise RuntimeError("No alternative time stamps available for loop mode (T==1).")

    # Fancy indexing goes through the underlying ndarray, not the wrapper.
    candidate_stack = matrices.array[keep]

    similarity_to_current = walker_util.cosine_similarity(current_matrix)
    sims = walker_util.apply_on_axis0(candidate_stack, similarity_to_current)
    probs = walker_util.l1_norm(sims)
    mass = float(probs.sum())
    _logger.debug("_step_time: candidates=%d, mass=%.6f", keep.size, mass)
    if mass <= 0.0:
        _logger.error(
            "_step_time: zero probability mass (t=%d, type=%s, candidates=%d)",
            time_stamp, interaction_type, keep.size
        )
        raise RuntimeError(
            "No valid time stamps to sample: probability mass is zero after masking/normalization."
        )

    next_time_stamp = int(self.rng.choice(keep, p=probs))
    _logger.debug("_step_time: chosen next_time_stamp=%d", next_time_stamp)
    return next_time_stamp, to_avoid
|
|
392
|
+
|
|
393
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
394
|
+
# PUBLIC
|
|
395
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
396
|
+
|
|
397
|
+
def walk(self,
         start_node: int | None,
         start_time_stamp: int | None,
         length: int,
         interaction_type: Literal["attr", "repuls"],
         self_avoid: bool,
         time_aware: bool = False,
         stickiness: float | None = None,
         on_no_options: Literal["raise", "loop"] | None = None) -> np.ndarray:
    """Generate one walk path.

    Indexing contract:
        The public API is **1-based** for both nodes and time stamps;
        internally everything is 0-based. The returned path is 1-based.

    Args:
        start_node: 1-based start node; if ``None``, drawn with ``rng.choice``
            over all nodes.
        start_time_stamp: 1-based start time; if ``None``, drawn with
            ``rng.choice`` over all time stamps.
        length: Number of transition **steps** to simulate.
        interaction_type: ``"attr"`` or ``"repuls"``.
        self_avoid: If ``True``, the visited nodes are passed to
            :meth:`_step_node` as the avoidance set.
        time_aware: If ``True``, call :meth:`_step_time` after each node step.
        stickiness: Required when ``time_aware=True``.
        on_no_options: Required when ``time_aware=True``.

    Returns:
        ``(length + 1,)`` array of dtype ``intp`` with **1-based** node indices.

    Raises:
        ValueError: Missing time-aware parameters or start indices out of
            range (after 1-based -> 0-based conversion).
        RuntimeError: Propagated from the step routines.
    """
    _logger.debug(
        "walk(start_node=%r, start_time_stamp=%r, length=%d, type=%s, self_avoid=%s, time_aware=%s)",
        start_node, start_time_stamp, length, interaction_type, self_avoid, time_aware
    )

    # Validate the argument combination before touching the RNG, so that a
    # rejected call does not advance the generator state.
    if time_aware and (stickiness is None or on_no_options is None):
        _logger.error("time_aware=True but stickiness/on_no_options missing")
        raise ValueError("time_aware=True requires both `stickiness` and `on_no_options`.")

    # 1-based external API preserved, so validate ranges after conversion
    if start_node is not None:
        node = int(start_node) - 1
        if not (0 <= node < self.node_count):
            _logger.error("start_node out of range after 1-based conversion: %r", start_node)
            raise ValueError(f"start_node out of range after 1-based conversion: {start_node}")
    else:
        node = int(self.rng.choice(self.nodes))

    if start_time_stamp is not None:
        time_stamp = int(start_time_stamp) - 1
        if not (0 <= time_stamp < self.time_stamp_count):
            _logger.error("start_time_stamp out of range after 1-based conversion: %r", start_time_stamp)
            raise ValueError(f"start_time_stamp out of range after 1-based conversion: {start_time_stamp}")
    else:
        time_stamp = int(self.rng.choice(self.time_stamps))

    _logger.debug("walk: initial node=%d, t=%d", node, time_stamp)

    nodes_to_avoid: np.ndarray | None = np.array([node], dtype=np.intp) if self_avoid else None
    time_stamps_to_avoid: np.ndarray | None = None

    # len(range(...)) gives the exact iteration count (0 for negative lengths),
    # which lets the path be allocated once instead of re-copied on every step.
    n_steps = len(range(length))
    pth = np.empty(n_steps + 1, dtype=np.intp)
    pth[0] = node

    for step in range(n_steps):
        if self_avoid:
            node, nodes_to_avoid = self._step_node(node, interaction_type, time_stamp, nodes_to_avoid)
        else:
            node, _ = self._step_node(node, interaction_type, time_stamp, avoid=None)
        pth[step + 1] = node

        if time_aware:
            time_stamp, time_stamps_to_avoid = self._step_time(
                time_stamp, interaction_type, stickiness, on_no_options, time_stamps_to_avoid
            )

    pth += 1  # ensure 1-based indexing in the output

    _logger.debug("walk: finished path of len=%d", pth.size)
    return pth
|
|
482
|
+
|
|
483
|
+
# deterministic per-batch worker: (start_nodes_batch, seedseq/int) -> stack of walks
|
|
484
|
+
def _walk_batch_with_seed(self, work_item, num_walks_from_each: int, *args, **kwargs):
    """Worker: reseed the RNG and generate walks for a batch of start nodes.

    Note that this replaces ``self.rng`` with a generator built from the
    seed carried in ``work_item``.

    Args:
        work_item: Tuple ``(start_nodes, seed_obj)``. ``start_nodes`` holds
            **0-based** node indices (converted to 1-based before calling
            :meth:`walk`); ``seed_obj`` is passed to
            ``np.random.default_rng``.
        num_walks_from_each: Number of walks generated per start node.
        *args, **kwargs: Forwarded to :meth:`walk` after the start node.

    Returns:
        ``(M, L+1)`` array of dtype ``uint16`` with 1-based node indices,
        where ``M = len(start_nodes) * num_walks_from_each``. If no walks
        are generated, an empty ``(0, L+1)`` array, where ``L`` is the
        ``length`` argument destined for :meth:`walk` (0 if not supplied).

    Raises:
        OverflowError: If a node index does not fit into ``uint16``
            (previously such indices wrapped around silently).
    """
    start_nodes, seed_obj = work_item
    start_nodes = np.asarray(start_nodes, dtype=np.intp)
    walks_each = int(num_walks_from_each)
    _logger.debug(
        "_walk_batch_with_seed: batch_size=%d, walks_each=%d",
        start_nodes.size, walks_each
    )
    self.rng = np.random.default_rng(seed_obj)  # SeedSequence or int OK

    out = [
        self.walk(int(snode) + 1, *args, **kwargs)  # 1-based API
        for snode in start_nodes
        for _ in range(walks_each)
    ]

    if not out:
        # `length` is the third parameter of walk(); after the start node it is
        # either the keyword "length" or the second positional argument.
        if "length" in kwargs:
            L = int(kwargs["length"])
        elif len(args) > 1:
            L = int(args[1])
        else:
            L = 0
        return np.empty((0, max(L, 0) + 1), dtype=np.uint16)

    stacked = np.stack(out, axis=0)
    limit = int(np.iinfo(np.uint16).max)
    if stacked.size and int(stacked.max()) > limit:
        _logger.error("_walk_batch_with_seed: node index %d exceeds uint16 range", int(stacked.max()))
        raise OverflowError(
            f"Node index {int(stacked.max())} does not fit into uint16 (max {limit})."
        )

    arr = stacked.astype(np.uint16, copy=False)
    _logger.debug("_walk_batch_with_seed: produced walks shape=%s dtype=%s", arr.shape, arr.dtype)
    return arr
|
|
519
|
+
|
|
520
|
+
def _walks_per_time(self,
                    work_items,
                    processor,
                    num_walks_from_each: int,
                    *,
                    walk_length: int,
                    interaction_type: Literal["attr", "repuls"],
                    self_avoid: bool,
                    time_aware: bool,
                    stickiness: float | None,
                    on_no_options: Literal["raise", "loop"] | None) -> np.ndarray:
    """
    Run ``processor`` once per start time stamp and stack the results as (T, M, L+1).

    For each time index ``t`` the processor is invoked with ``work_items``,
    :meth:`_walk_batch_with_seed`, the per-node walk count, and the walk
    keyword arguments, using the 1-based start time ``t + 1``. The chunks
    it returns are concatenated into one ``uint16`` layer.

    - If time_aware=False: layer t holds walks constrained to time t.
    - If time_aware=True: layer t holds walks that *start* at t and may
      move in time via _step_time.
    """
    width = walk_length + 1
    layers: list[np.ndarray] = []

    for t in range(self.time_stamp_count):
        chunks = processor(
            work_items,
            self._walk_batch_with_seed,
            int(num_walks_from_each),
            start_time_stamp=int(t + 1),
            length=walk_length,
            interaction_type=interaction_type,
            self_avoid=self_avoid,
            time_aware=bool(time_aware),
            stickiness=stickiness,
            on_no_options=on_no_options,
        )
        # A falsy result (no chunks) becomes an empty layer of the right width.
        layer = (
            np.concatenate(chunks, axis=0).astype(np.uint16, copy=False)
            if chunks
            else np.empty((0, width), dtype=np.uint16)
        )
        layers.append(layer)

    if layers:
        arr_3d = np.stack(layers, axis=0)
    else:
        arr_3d = np.empty((self.time_stamp_count, 0, width), dtype=np.uint16)

    _logger.info("_walks_per_time: produced (T,M,L+1)=%s for type=%s, self_avoid=%s, time_aware=%s",
                 arr_3d.shape, interaction_type, self_avoid, time_aware)
    return arr_3d
|
|
562
|
+
|
|
563
|
+
def sample_walks(self,
                 # walks
                 walk_length: int,
                 walks_per_node: int,
                 saw_frac: float = 0.0,
                 include_attractive: bool = True,
                 include_repulsive: bool = False,
                 # time aware params
                 time_aware: bool = False,
                 stickiness: float | None = None,
                 on_no_options: Literal["raise", "loop"] | None = None,
                 # output
                 output_path: str | Path | None = None,
                 *,
                 # computation
                 in_parallel: bool,
                 # storage
                 compression_level: int = 3,
                 num_walk_matrices_in_compressed_blocks: int | None = None
                 ) -> str:
    """Sample walks from every node and store them in a compressed archive.

    ``walks_per_node`` is split into self-avoiding walks (SAWs) and plain
    random walks (RWs): ``round(walks_per_node * saw_frac)`` SAWs, the rest
    RWs. For each enabled channel (attractive and/or repulsive) the RWs are
    generated and written first, then the SAWs.

    Output layout:
        Every block is a 3-D array of shape ``(T, M, L+1)``:
          - ``T``: number of time stamps,
          - ``M``: number of walks in a layer,
          - ``L+1``: path length including the start node.
        With ``time_aware=False`` layer t holds walks constrained to time t;
        with ``time_aware=True`` layer t holds walks that **start** at t but
        may evolve in time.

    Args:
        walk_length: Number of steps per walk (paths have ``walk_length+1``
            entries).
        walks_per_node: Total number of walks per node.
        saw_frac: Fraction in ``[0, 1]`` of the per-node walks that are SAWs.
        include_attractive: Generate and store the attractive-channel blocks.
        include_repulsive: Generate and store the repulsive-channel blocks.
        time_aware: Forwarded to the walk generator to enable time stepping.
        stickiness: Forwarded to the walk generator (documented upstream as
            required when ``time_aware=True``).
        on_no_options: Forwarded to the walk generator; ``"raise"`` or
            ``"loop"`` (documented upstream as required when
            ``time_aware=True``).
        output_path: Archive destination. A falsy value selects
            ``WALKS_<timestamp>`` in the current working directory. The final
            suffix is always forced to ``.zip``.
        in_parallel: Run node batches through a
            :class:`ProcessPoolExecutor`; requires a main-process guard.
        compression_level: Passed to the archive writer.
        num_walk_matrices_in_compressed_blocks: ``arrays_per_chunk`` used for
            every write; a falsy value falls back to the number of node
            batches.

    Returns:
        The archive path as a string.

    Raises:
        ValueError: If ``saw_frac`` is not within ``[0, 1]``.
        RuntimeError: If ``in_parallel`` is set outside the main process.
            Errors raised by the underlying walk generation are not caught
            here and propagate to the caller.
    """
    _logger.info(
        "sample_walks: L=%d, per_node=%d, saw_frac=%.3f, time_aware=%s, out=%s, "
        "parallel=%s, arrays_per_chunk=%s",
        walk_length, walks_per_node, saw_frac, time_aware, output_path, in_parallel,
        num_walk_matrices_in_compressed_blocks,
    )

    current_time = sawnergy_util.current_time()

    # Resolve the archive location. Note that ``with_suffix`` replaces an
    # existing final suffix, so e.g. ``foo.bar`` ends up as ``foo.zip``.
    if output_path:
        archive_base = output_path
    else:
        archive_base = Path(os.getcwd()) / f"WALKS_{current_time}"
    output_path = Path(archive_base).with_suffix(".zip")
    _logger.debug("Output archive path: %s", output_path)

    # Written as a negated chained comparison so that NaN is rejected too.
    if not (0.0 <= saw_frac <= 1.0):
        _logger.error("saw_frac out of range: %r", saw_frac)
        raise ValueError("saw_frac must be in [0, 1]")

    # Deterministic integer split between the two walk flavours.
    num_SAWs = int(round(walks_per_node * float(saw_frac)))
    num_RWs = int(walks_per_node) - num_SAWs
    _logger.info("Per-node counts: SAWs=%d, RWs=%d", num_SAWs, num_RWs)

    num_workers = os.cpu_count() or 1
    if in_parallel:
        batch_size_nodes = num_workers
    else:
        batch_size_nodes = 1
    _logger.debug("Workers=%d, batch_size_nodes=%d", num_workers, batch_size_nodes)

    if in_parallel and not sawnergy_util.is_main_process():
        _logger.error("Process-based parallelism requires main-process guard")
        raise RuntimeError(
            "Process-based parallelism requires running under `if __name__ == '__main__':`."
        )

    processor = sawnergy_util.elementwise_processor(
        in_parallel=in_parallel,
        Executor=ProcessPoolExecutor,
        max_workers=num_workers,
        capture_output=True
    )
    _logger.debug("elementwise_processor created (parallel=%s, workers=%d)", in_parallel, num_workers)

    # Node batches are materialised up front so their count is known.
    _logger.debug("Building node batches via sawnergy_util.batches_of (batch_size_nodes=%d)", batch_size_nodes)
    node_batches = list(sawnergy_util.batches_of(self.nodes, batch_size=batch_size_nodes))
    batch_count = len(node_batches)
    _logger.debug("Built %d node batches", batch_count)

    # One child seed per batch, all spawned from the walker's master seed.
    # The same (batch, seed) pairs are reused for every channel and mode.
    seed_root = np.random.SeedSequence(self._seed)
    work_items = list(zip(node_batches, seed_root.spawn(batch_count)))
    _logger.debug("Prepared %d work_items with child seeds", len(work_items))

    num_walk_matrices_in_compressed_blocks = (
        num_walk_matrices_in_compressed_blocks or batch_count
    )
    _logger.info("arrays_per_chunk resolved to: %d", num_walk_matrices_in_compressed_blocks)

    attractive_RWs_name = "ATTRACTIVE_RWs"
    attractive_SAWs_name = "ATTRACTIVE_SAWs"
    repulsive_RWs_name = "REPULSIVE_RWs"
    repulsive_SAWs_name = "REPULSIVE_SAWs"

    # Ordered generation plan; each row is
    # (interaction_type, self_avoid, walks per node, block name,
    #  start log message, finish log message).
    plan = []
    if include_attractive:
        plan.append(("attr", False, num_RWs, attractive_RWs_name,
                     "Generating ATTR RWs ...", "ATTR RWs (per-time): shape=%s"))
        plan.append(("attr", True, num_SAWs, attractive_SAWs_name,
                     "Generating ATTR SAWs ...", "ATTR SAWs (per-time): shape=%s"))
    if include_repulsive:
        plan.append(("repuls", False, num_RWs, repulsive_RWs_name,
                     "Generating REPULS RWs ...", "REPULS RWs (per-time): shape=%s"))
        plan.append(("repuls", True, num_SAWs, repulsive_SAWs_name,
                     "Generating REPULS SAWs ...", "REPULS SAWs (per-time): shape=%s"))

    def _metadata():
        # Lazily yields (key, value) pairs so every value is evaluated right
        # before it is stored, in exactly this order.
        yield "time_created", current_time
        yield "seed", int(self._seed)
        yield "rng_scheme", "SeedSequence.spawn_per_batch_v1"
        yield "num_workers", int(num_workers)
        yield "in_parallel", bool(in_parallel)
        yield "batch_size_nodes", int(batch_size_nodes)
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        yield "num_RWs", num_RWs
        yield "num_SAWs", num_SAWs
        yield "node_count", self.node_count
        yield "walk_length", walk_length
        yield "walks_per_node", walks_per_node
        yield "time_stamp_count", self.time_stamp_count
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        # Block names are recorded as None for channels that were skipped.
        yield "attractive_RWs_name", attractive_RWs_name if include_attractive else None
        yield "repulsive_RWs_name", repulsive_RWs_name if include_repulsive else None
        yield "attractive_SAWs_name", attractive_SAWs_name if include_attractive else None
        yield "repulsive_SAWs_name", repulsive_SAWs_name if include_repulsive else None
        yield "walks_layout", "time_leading_3d"  # (T, M, L+1) for all modes

    with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level) as storage:
        for interaction, avoid, per_node, block_name, start_msg, done_msg in plan:
            _logger.info(start_msg)
            walks_3d = self._walks_per_time(
                work_items, processor, per_node,
                walk_length=walk_length,
                interaction_type=interaction,
                self_avoid=avoid,
                time_aware=time_aware,
                stickiness=stickiness,
                on_no_options=on_no_options,
            )
            _logger.info(done_msg, walks_3d.shape)
            storage.write(
                walks_3d,
                to_block_named=block_name,
                arrays_per_chunk=num_walk_matrices_in_compressed_blocks
            )

        # useful metadata
        for key, value in _metadata():
            storage.add_attr(key, value)

        _logger.info("Wrote metadata")

    archive_str = str(output_path)
    _logger.info("sample_walks complete -> %s", archive_str)
    return archive_str
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
# Names exported by ``from <this module> import *``.
__all__ = ["Walker"]

# Running this module as a script is intentionally a no-op.
if __name__ == "__main__":
    pass
|