sawnergy 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sawnergy might be problematic. Click here for more details.

@@ -0,0 +1,795 @@
1
+ from __future__ import annotations
2
+
3
+ # third-party
4
+ import numpy as np
5
+ # built-in
6
+ from pathlib import Path
7
+ from typing import Literal
8
+ from concurrent.futures import ProcessPoolExecutor
9
+ import logging
10
+ import os
11
+ # local
12
+ from . import walker_util
13
+ from .. import sawnergy_util
14
+
15
+ # *----------------------------------------------------*
16
+ # GLOBALS
17
+ # *----------------------------------------------------*
18
+
19
+ _logger = logging.getLogger(__name__)
20
+
21
+ # *----------------------------------------------------*
22
+ # CLASSES
23
+ # *----------------------------------------------------*
24
+
25
class Walker:
    """Random-walk sampler over time-indexed **transition** matrices.

    Loads per-timestamp stacks of attractive/repulsive **transition** matrices
    (shape ``(T, N, N)``) previously written by the RIN builder, exposes
    sampling for random walks (RW) and self-avoiding walks (SAW), and can
    optionally advance a *time* coordinate using cosine-similarity between
    transition slices (time-aware walks).

    Matrices are wrapped in :class:`walker_util.SharedNDArray` so that worker
    processes can read the same buffers. Each ``Walker`` instance owns a
    dedicated :class:`numpy.random.Generator` seeded from a master seed for
    reproducibility.

    The instance is a context manager; leaving the ``with`` block (or calling
    :meth:`close`) releases the shared-memory handles.
    """

    def __init__(self,
                 RIN_path: str | Path,
                 *,
                 seed: int | None = None) -> None:
        """Initialize shared matrices and RNG.

        Data source:
            Transition dataset names are resolved from the archive attributes
            ``'attractive_transitions_name'`` and ``'repulsive_transitions_name'``.
            If either attribute is ``None``, that channel is unavailable.

        Args:
            RIN_path: Path to an ``ArrayStorage`` archive (.zip) containing
                **transition** matrices written by the builder.
            seed: Optional master seed for this instance's RNG. If ``None``,
                a random 32-bit seed is chosen.

        Raises:
            ValueError: If no matrices are found or arrays are not rank-3.
            RuntimeError: If attractive/repulsive shapes differ or matrices
                are not square along the last two axes.
        """
        _logger.info("Initializing Walker from %s", RIN_path)

        # Safe defaults first: close()/__del__ must behave even if construction
        # fails part-way through. Nothing is allocated yet, so "cleaned up".
        self.attr_matrices: walker_util.SharedNDArray | None = None
        self.repuls_matrices: walker_util.SharedNDArray | None = None
        self._memory_cleaned_up: bool = True

        # Load numpy arrays from read-only storage
        with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
            attr_name = storage.get_attr("attractive_transitions_name")
            repuls_name = storage.get_attr("repulsive_transitions_name")
            attr_matrices: np.ndarray | None = (
                storage.read(attr_name, slice(None)) if attr_name is not None else None
            )
            repuls_matrices: np.ndarray | None = (
                storage.read(repuls_name, slice(None)) if repuls_name is not None else None
            )

        _logger.debug(
            "Loaded matrices | attr: shape=%s dtype=%s | repuls: shape=%s dtype=%s",
            getattr(attr_matrices, "shape", None), getattr(attr_matrices, "dtype", None),
            getattr(repuls_matrices, "shape", None), getattr(repuls_matrices, "dtype", None),
        )

        # Shape & consistency checks (expect (T, N, N))
        T, N = self._validate_transition_stacks(attr_matrices, repuls_matrices, RIN_path)
        _logger.info("Transition stack OK: T=%d, N=%d", T, N)

        # SHARED MEMORY ELEMENTS
        # From here on there may be segments to release, so flip the flag
        # *before* allocating; a failure on the second channel must not leak
        # the first one.
        self._memory_cleaned_up = False
        try:
            self.attr_matrices = self._share(attr_matrices)
            self.repuls_matrices = self._share(repuls_matrices)
        except BaseException:
            self.close()
            raise

        _logger.debug(
            "SharedNDArray created | attr name=%r; repuls name=%r",
            getattr(self.attr_matrices, "name", None),
            getattr(self.repuls_matrices, "name", None),
        )

        # AUXILIARY NETWORK INFORMATION
        self.time_stamp_count = T
        self.node_count = N

        # NETWORK ELEMENT (0-based index arrays used as sampling populations)
        self.nodes = np.arange(0, self.node_count, 1, np.intp)
        self.time_stamps = np.arange(0, self.time_stamp_count, 1, np.intp)
        _logger.debug(
            "Index arrays built: nodes=%d, time_stamps=%d",
            self.nodes.size, self.time_stamps.size
        )

        # INTERNAL
        self._seed = self._random_seed() if seed is None else int(seed)
        self.rng = np.random.default_rng(self._seed)
        _logger.info("RNG initialized (master seed=%d)", self._seed)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
    #                      CONSTRUCTION HELPERS
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #

    @staticmethod
    def _random_seed() -> int:
        """Return a fresh random seed in ``[0, 2**32)``.

        Uses OS entropy through :class:`numpy.random.SeedSequence` instead of
        the legacy global ``np.random.randint``, which rejects a bound of
        ``2**32 - 1`` where the default integer type is 32-bit and also
        depends on global RNG state.
        """
        return int(np.random.SeedSequence().generate_state(1, dtype=np.uint32)[0])

    @staticmethod
    def _validate_transition_stacks(attr_matrices: np.ndarray | None,
                                    repuls_matrices: np.ndarray | None,
                                    source: str | Path) -> tuple[int, int]:
        """Check rank/shape consistency of the loaded stacks; return ``(T, N)``.

        Raises:
            ValueError: If both channels are missing or any present array is
                not rank-3.
            RuntimeError: If both channels are present with different shapes,
                or the last two axes are not equal.
        """
        present = [m for m in (attr_matrices, repuls_matrices) if m is not None]

        if not present:
            _logger.error("No transition matrices detected in %s", source)
            raise ValueError("No transition matrices detected.")

        if len(present) == 2:
            if attr_matrices.ndim != 3 or repuls_matrices.ndim != 3:
                _logger.error(
                    "Bad ranks: attr.ndim=%s repuls.ndim=%s; expected both 3",
                    attr_matrices.ndim, repuls_matrices.ndim,
                )
                raise ValueError(
                    f"Expected (T,N,N) arrays; got {attr_matrices.shape} "
                    f"and {repuls_matrices.shape}"
                )
            if attr_matrices.shape != repuls_matrices.shape:
                _logger.error("Shape mismatch: attr=%s repuls=%s",
                              attr_matrices.shape, repuls_matrices.shape)
                raise RuntimeError(
                    f"ATTR/REPULS shapes must match exactly; got {attr_matrices.shape} vs {repuls_matrices.shape}"
                )
        elif present[0].ndim != 3:
            raise ValueError(f"Expected (T,N,N); got {present[0].shape}")

        T, N1, N2 = present[0].shape
        if N1 != N2:
            _logger.error("Non-square matrices along last two dims: (%s, %s)", N1, N2)
            raise RuntimeError(
                f"Transition matrices must be square along last two dims; got ({N1}, {N2})"
            )
        return T, N1

    @staticmethod
    def _share(matrices: np.ndarray | None) -> walker_util.SharedNDArray | None:
        """Copy ``matrices`` into a ``SharedNDArray`` (``None`` passes through)."""
        if matrices is None:
            return None
        return walker_util.SharedNDArray.create(
            shape=matrices.shape,
            dtype=matrices.dtype,
            from_array=matrices,
        )

    @property
    def _walk_dtype(self) -> np.dtype:
        """Unsigned dtype used for stored walks.

        ``uint16`` whenever the largest 1-based node index (``node_count``)
        fits, which keeps the historical output format; ``uint32`` otherwise,
        so indices above 65535 are not silently wrapped.
        """
        if int(self.node_count) <= np.iinfo(np.uint16).max:
            return np.dtype(np.uint16)
        return np.dtype(np.uint32)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
    #                       RESOURCE MANAGEMENT
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #

    def close(self) -> None:
        """Close shared-memory handles and (in main process) unlink segments.

        Idempotent: if cleanup already occurred, returns immediately. Always
        closes local handles in the current process. If the caller is the main
        process (per ``sawnergy_util.is_main_process()``), also attempts to
        unlink the underlying shared-memory segments (``FileNotFoundError`` is
        logged and suppressed, i.e. already unlinked elsewhere).

        Each channel is handled independently, so a failure on one does not
        prevent releasing the other. The first unexpected error is re-raised
        once every handle has been processed.
        """
        if self._memory_cleaned_up:
            _logger.debug("close(): already cleaned up; returning")
            return

        is_main = sawnergy_util.is_main_process()
        _logger.debug("Closing Walker resources (is_main=%s)", is_main)
        if is_main:
            _logger.debug("Attempting to unlink shared memory segments (main process)")
        else:
            _logger.debug("Not main process; skipping unlink")

        first_error: Exception | None = None
        try:
            channels = (("attr", self.attr_matrices), ("repuls", self.repuls_matrices))
            for label, shared in channels:
                if shared is None:
                    continue
                try:
                    shared.close()
                except Exception as exc:  # keep going: the other channel still needs release
                    _logger.warning("%s SharedNDArray close failed: %r", label, exc)
                    if first_error is None:
                        first_error = exc
                if not is_main:
                    continue
                try:
                    shared.unlink()
                except FileNotFoundError:
                    _logger.warning("%s SharedMemory already unlinked", label)
                except Exception as exc:
                    _logger.warning("%s SharedMemory unlink failed: %r", label, exc)
                    if first_error is None:
                        first_error = exc
            _logger.debug("SharedNDArray handles closed")
        finally:
            # Mark as done even on failure so __del__ does not retry endlessly.
            self._memory_cleaned_up = True
            _logger.debug("Cleanup complete")

        if first_error is not None:
            raise first_error

    def __enter__(self):
        """Enter context manager scope."""
        _logger.debug("__enter__")
        return self

    def __exit__(self, exc_type, exc, tb):
        """Exit context manager scope and perform cleanup."""
        _logger.debug("__exit__(exc_type=%s)", getattr(exc_type, "__name__", exc_type))
        self.close()

    def __del__(self):
        """Best-effort destructor cleanup (exceptions suppressed)."""
        try:
            # Default True: a half-constructed instance has nothing to release.
            if not getattr(self, "_memory_cleaned_up", True):
                _logger.debug("__del__: best-effort close")
                self.close()
        except Exception as e:
            _logger.debug("__del__ suppressed exception: %r", e)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
    #                               PRIVATE
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #

    def _matrices_of_interaction_type(
        self, interaction_type: Literal["attr", "repuls"]
    ) -> walker_util.SharedNDArray:
        """Return the shared array wrapper for the requested interaction type.

        Args:
            interaction_type: Either ``"attr"`` or ``"repuls"``.

        Returns:
            SharedNDArray wrapper exposing the ``(T, N, N)`` stack.

        Raises:
            ValueError: If the channel is missing or the type is invalid.
        """
        _logger.debug("_matrices_of_interaction_type(%s)", interaction_type)
        if interaction_type == "attr":
            if self.attr_matrices is None:
                raise ValueError(
                    "Attractive transition matrices are not present in the RIN archive."
                )
            return self.attr_matrices
        if interaction_type == "repuls":
            if self.repuls_matrices is None:
                raise ValueError(
                    "Repulsive transition matrices are not present in the RIN archive."
                )
            return self.repuls_matrices
        _logger.error("interaction_type invalid: %r", interaction_type)
        raise ValueError("`interaction_type` must be 'attr' or 'repuls'.")

    def _extract_prob_vector(
        self,
        node: int,
        time_stamp: int,
        interaction_type: Literal["attr", "repuls"],
    ) -> np.ndarray:
        """Copy the transition row for ``node`` at ``time_stamp``.

        Both indices are 0-based.

        Returns:
            ``(N,)`` array with transition probabilities/weights, detached
            from the shared buffer.
        """
        _logger.debug("_extract_prob_vector(node=%d, t=%d, type=%s)",
                      node, time_stamp, interaction_type)
        matrix = self._matrices_of_interaction_type(interaction_type)[time_stamp]
        vec = matrix[node, :].copy()  # detach from shared buffer
        _logger.debug("prob vector extracted: shape=%s dtype=%s", vec.shape, vec.dtype)
        return vec

    def _step_node(
        self,
        node: int,
        interaction_type: Literal["attr", "repuls"],
        time_stamp: int = 0,
        avoid: np.typing.ArrayLike | None = None,
    ) -> tuple[int, np.ndarray | None]:
        """Sample next node given current node and optional avoidance set.

        Args:
            node: Current 0-based node.
            interaction_type: ``"attr"`` or ``"repuls"``.
            time_stamp: 0-based time stamp whose transition matrix is used.
            avoid: 0-based nodes that may not be chosen; ``None`` disables
                avoidance.

        Returns:
            ``(next_node, updated_avoid)`` where ``updated_avoid`` is ``None`` if
            avoidance is disabled, otherwise ``avoid`` with ``next_node`` appended.

        Raises:
            RuntimeError: If every node is avoided or the remaining
                probability mass is zero.
        """
        _logger.debug(
            "_step_node(node=%d, t=%d, type=%s, avoid_len=%s)",
            node, time_stamp, interaction_type,
            None if avoid is None else np.asarray(avoid).size,
        )
        prob_dist = self._extract_prob_vector(node, time_stamp, interaction_type)

        if avoid is None:
            mass = float(np.sum(prob_dist))
            _logger.debug("_step_node: no-avoid branch; mass=%.6f", mass)
            if mass <= 0.0:
                _logger.error("_step_node: zero probability mass without avoidance")
                raise RuntimeError("No valid node transitions: zero probability mass.")
            # The stored row is used as-is, so it must already sum to 1.
            return int(self.rng.choice(self.nodes, p=prob_dist)), None

        to_avoid = np.asarray(avoid, dtype=np.intp)
        keep = np.setdiff1d(self.nodes, to_avoid, assume_unique=False)
        _logger.debug("_step_node: keep.size=%d (after removing %d avoids)",
                      keep.size, to_avoid.size)
        if keep.size == 0:
            _logger.error("_step_node: empty candidate set after avoidance")
            raise RuntimeError("No available node transitions (avoiding all nodes).")

        # Renormalize over the surviving candidates only.
        probs = walker_util.l1_norm(prob_dist[keep])
        mass = float(probs.sum())
        _logger.debug("_step_node: normalized mass=%.6f", mass)
        if mass <= 0.0:
            _logger.error("_step_node: zero probability mass after masking/normalization")
            raise RuntimeError("No valid node transitions: probability mass is zero.")

        next_node = int(self.rng.choice(keep, p=probs))
        _logger.debug("_step_node: chosen next_node=%d", next_node)
        to_avoid = np.append(to_avoid, next_node).astype(np.intp, copy=False)

        return next_node, to_avoid

    def _step_time(
        self,
        time_stamp: int,
        interaction_type: Literal["attr", "repuls"],
        stickiness: float,
        on_no_options: Literal["raise", "loop"],
        avoid: np.typing.ArrayLike | None,
    ) -> tuple[int, np.ndarray | None]:
        """Sample next time stamp given stickiness and similarity.

        With probability ``stickiness`` the current time stamp is kept.
        Otherwise the current time stamp joins the avoid set and the next one
        is drawn from the remaining time stamps, weighted by the (L1-normalized)
        cosine similarity between their matrices and the current matrix.

        Args:
            time_stamp: Current 0-based time stamp.
            interaction_type: ``"attr"`` or ``"repuls"``.
            stickiness: Probability in ``[0, 1]`` of staying at ``time_stamp``.
            on_no_options: What to do once every time stamp is avoided:
                ``"raise"`` raises, ``"loop"`` resets the avoid set to just
                the current time stamp.
            avoid: 0-based time stamps that may not be chosen, or ``None``.

        Returns:
            ``(next_time_stamp, updated_avoid)``.

        Raises:
            ValueError: If ``stickiness`` not in ``[0,1]`` or ``on_no_options`` invalid.
            RuntimeError: If no candidates or zero probability mass.
        """
        _logger.debug(
            "_step_time(t=%d, type=%s, stickiness=%.3f, on_no_options=%s, avoid_len=%s)",
            time_stamp, interaction_type, stickiness, on_no_options,
            None if avoid is None else np.asarray(avoid).size,
        )
        if not (0.0 <= stickiness <= 1.0):
            _logger.error("stickiness out of range: %r", stickiness)
            raise ValueError("stickiness must be in [0,1]")
        # Validate eagerly: a typo must not stay hidden until candidates run out.
        if on_no_options not in ("raise", "loop"):
            _logger.error("_step_time: invalid on_no_options=%r", on_no_options)
            raise ValueError("on_no_options must be 'raise' or 'loop'")

        to_avoid = np.array([], dtype=np.intp) if avoid is None else np.asarray(avoid, dtype=np.intp)

        # With probability `stickiness`, remain at the same time stamp
        r = float(self.rng.random())
        _logger.debug("_step_time: rand=%.6f vs stickiness=%.6f", r, float(stickiness))
        if r < float(stickiness):
            _logger.debug("_step_time: sticking at t=%d", time_stamp)
            return int(time_stamp), to_avoid

        # Exclude current time since we chose not to stick
        to_avoid = np.unique(np.append(to_avoid, time_stamp).astype(np.intp, copy=False))
        keep = np.setdiff1d(self.time_stamps, to_avoid, assume_unique=True)
        _logger.debug("_step_time: keep.size=%d (to_avoid.size=%d)", keep.size, to_avoid.size)

        matrices = self._matrices_of_interaction_type(interaction_type)
        current_matrix = matrices[time_stamp]

        if keep.size == 0:
            if on_no_options == "raise":
                _logger.error("_step_time: no available timestamps (avoid=%s)", to_avoid)
                raise RuntimeError(f"No available time stamps (avoid={to_avoid})")
            # on_no_options == "loop": forget history, allow everything but "now"
            _logger.warning("_step_time: looping over all except current (t=%d)", time_stamp)
            to_avoid = np.array([time_stamp], dtype=np.intp)
            keep = self.time_stamps[self.time_stamps != time_stamp]
            if keep.size == 0:
                _logger.error("_step_time: loop mode impossible (T==1)")
                raise RuntimeError("No alternative time stamps available for loop mode (T==1).")

        matrices_stack = matrices.array[keep]  # fancy indexing on ndarray, not wrapper

        sims = walker_util.apply_on_axis0(matrices_stack, walker_util.cosine_similarity(current_matrix))
        probs = walker_util.l1_norm(sims)
        mass = float(probs.sum())
        _logger.debug("_step_time: candidates=%d, mass=%.6f", keep.size, mass)
        if mass <= 0.0:
            _logger.error(
                "_step_time: zero probability mass (t=%d, type=%s, candidates=%d)",
                time_stamp, interaction_type, keep.size
            )
            raise RuntimeError(
                "No valid time stamps to sample: probability mass is zero after masking/normalization."
            )

        next_time_stamp = int(self.rng.choice(keep, p=probs))
        _logger.debug("_step_time: chosen next_time_stamp=%d", next_time_stamp)
        return next_time_stamp, to_avoid

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
    #                                PUBLIC
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #

    def walk(self,
             start_node: int | None,
             start_time_stamp: int | None,
             length: int,
             interaction_type: Literal["attr", "repuls"],
             self_avoid: bool,
             time_aware: bool = False,
             stickiness: float | None = None,
             on_no_options: Literal["raise", "loop"] | None = None) -> np.ndarray:
        """Generate one walk path.

        Indexing contract:
            Public API is **1-based** for both nodes and time stamps to match
            residue numbering in biomolecular contexts. Internally everything is
            0-based. The returned path is 1-based.

        Args:
            start_node: 1-based start node; if ``None``, sampled uniformly.
            start_time_stamp: 1-based start time; if ``None``, sampled uniformly.
            length: Number of transition **steps** to simulate.
            interaction_type: ``"attr"`` or ``"repuls"``.
            self_avoid: If ``True``, path will not revisit nodes within the same walk.
            time_aware: If ``True``, advance time with :meth:`_step_time` each step.
            stickiness: Required when ``time_aware=True``; probability of staying
                at the current time.
            on_no_options: Required when ``time_aware=True``; behavior when no
                time candidates are available (``"raise"`` or ``"loop"``).

        Returns:
            ``(length + 1,)`` array of dtype ``intp`` with **1-based** node indices.

        Raises:
            ValueError: Bad start indices (after 1-based→0-based) or missing
                time-aware parameters.
            RuntimeError: Propagated from step routines when no valid choices exist.
        """
        _logger.debug(
            "walk(start_node=%r, start_time_stamp=%r, length=%d, type=%s, self_avoid=%s, time_aware=%s)",
            start_node, start_time_stamp, length, interaction_type, self_avoid, time_aware
        )

        # Reject an unusable configuration before consuming any randomness.
        if time_aware and (stickiness is None or on_no_options is None):
            _logger.error("time_aware=True but stickiness/on_no_options missing")
            raise ValueError("time_aware=True requires both `stickiness` and `on_no_options`.")

        # 1-based external API preserved, so validate ranges after conversion
        if start_node is not None:
            node = int(start_node) - 1
            if not (0 <= node < self.node_count):
                _logger.error("start_node out of range after 1-based conversion: %r", start_node)
                raise ValueError(f"start_node out of range after 1-based conversion: {start_node}")
        else:
            node = int(self.rng.choice(self.nodes))

        if start_time_stamp is not None:
            time_stamp = int(start_time_stamp) - 1
            if not (0 <= time_stamp < self.time_stamp_count):
                _logger.error("start_time_stamp out of range after 1-based conversion: %r", start_time_stamp)
                raise ValueError(f"start_time_stamp out of range after 1-based conversion: {start_time_stamp}")
        else:
            time_stamp = int(self.rng.choice(self.time_stamps))

        _logger.debug("walk: initial node=%d, t=%d", node, time_stamp)

        nodes_to_avoid: np.ndarray | None = np.array([node], dtype=np.intp) if self_avoid else None
        time_stamps_to_avoid: np.ndarray | None = None

        # Preallocate the path instead of re-copying it with np.append each step.
        n_steps = len(range(length))  # same step count as iterating range(length)
        pth = np.empty(n_steps + 1, dtype=np.intp)
        pth[0] = node

        for step in range(1, n_steps + 1):
            node, nodes_to_avoid = self._step_node(
                node, interaction_type, time_stamp,
                nodes_to_avoid if self_avoid else None,
            )
            pth[step] = node

            if time_aware:
                time_stamp, time_stamps_to_avoid = self._step_time(
                    time_stamp, interaction_type, stickiness, on_no_options, time_stamps_to_avoid
                )

        pth += 1  # ensure 1-based indexing in the output

        _logger.debug("walk: finished path of len=%d", pth.size)
        return pth

    # deterministic per-batch worker: (start_nodes_batch, seedseq/int) -> stack of walks
    def _walk_batch_with_seed(self, work_item, num_walks_from_each: int, *args, **kwargs):
        """Worker: seed RNG and generate a batch of walks for a set of start nodes.

        Replaces ``self.rng`` with a generator built from the work item's seed,
        so the result depends only on the work item and the walk arguments.

        Args:
            work_item: Tuple ``(start_nodes, seed_obj)`` where ``start_nodes`` is
                an iterable of **0-based** node indices and ``seed_obj`` is an
                ``np.random.SeedSequence`` or an ``int``.
            num_walks_from_each: Number of walks to generate per start node. If 0,
                returns an empty array with the correct width.
            *args, **kwargs: Forwarded to :meth:`walk` after the start node.

        Returns:
            ``(M, L+1)`` array with 1-based node indices, where
            ``M = len(start_nodes) * num_walks_from_each`` and ``L`` is
            ``kwargs["length"]`` (defaults to 0 if absent). The dtype is
            ``uint16`` unless ``node_count`` exceeds 65535, in which case it
            is ``uint32``. If no walks are generated, returns ``(0, L+1)``.
        """
        start_nodes, seed_obj = work_item
        start_nodes = np.asarray(start_nodes, dtype=np.intp)
        walks_each = int(num_walks_from_each)
        _logger.debug(
            "_walk_batch_with_seed: batch_size=%d, walks_each=%d",
            start_nodes.size, walks_each
        )
        self.rng = np.random.default_rng(seed_obj)  # SeedSequence or int OK

        out = [
            self.walk(int(snode) + 1, *args, **kwargs)  # 1-based API
            for snode in start_nodes
            for _ in range(walks_each)
        ]

        dtype = self._walk_dtype
        if not out:
            L = int(kwargs.get("length", 0))
            return np.empty((0, L + 1), dtype=dtype)

        arr = np.stack(out, axis=0).astype(dtype, copy=False)
        _logger.debug("_walk_batch_with_seed: produced walks shape=%s dtype=%s", arr.shape, arr.dtype)
        return arr

    def _walks_per_time(self,
                        work_items,
                        processor,
                        num_walks_from_each: int,
                        *,
                        walk_length: int,
                        interaction_type: Literal["attr", "repuls"],
                        self_avoid: bool,
                        time_aware: bool,
                        stickiness: float | None,
                        on_no_options: Literal["raise", "loop"] | None) -> np.ndarray:
        """
        Generate walks separately for each start time stamp and stack as (T, M, L+1).

        - If time_aware=False: each layer t contains walks constrained to time t.
        - If time_aware=True: each layer t contains walks that *start* at t and
          may traverse time via _step_time during the walk.

        Args:
            work_items: Sequence of ``(start_nodes, seed_obj)`` tuples handed
                to :meth:`_walk_batch_with_seed`.
            processor: Callable built by ``sawnergy_util.elementwise_processor``;
                called as ``processor(work_items, fn, *args, **kwargs)`` and
                expected to yield one 2-D array per work item.
            num_walks_from_each: Walks per start node.

        Returns:
            ``(T, M, L+1)`` array of 1-based node indices (see
            :attr:`_walk_dtype` for the dtype).
        """
        dtype = self._walk_dtype
        per_time: list[np.ndarray] = []
        for t in range(self.time_stamp_count):
            chunks = processor(
                work_items,
                self._walk_batch_with_seed,
                int(num_walks_from_each),
                start_time_stamp=int(t + 1),  # public walk() API is 1-based
                length=walk_length,
                interaction_type=interaction_type,
                self_avoid=self_avoid,
                time_aware=bool(time_aware),
                stickiness=stickiness,
                on_no_options=on_no_options,
            )
            if chunks:
                all_walks_2d = np.concatenate(chunks, axis=0).astype(dtype, copy=False)
            else:
                all_walks_2d = np.empty((0, walk_length + 1), dtype=dtype)
            per_time.append(all_walks_2d)

        arr_3d = (np.stack(per_time, axis=0)
                  if per_time else np.empty((self.time_stamp_count, 0, walk_length + 1), dtype=dtype))
        _logger.info("_walks_per_time: produced (T,M,L+1)=%s for type=%s, self_avoid=%s, time_aware=%s",
                     arr_3d.shape, interaction_type, self_avoid, time_aware)
        return arr_3d

    def sample_walks(self,
                     # walks
                     walk_length: int,
                     walks_per_node: int,
                     saw_frac: float = 0.0,
                     include_attractive: bool = True,
                     include_repulsive: bool = False,
                     # time aware params
                     time_aware: bool = False,
                     stickiness: float | None = None,
                     on_no_options: Literal["raise", "loop"] | None = None,
                     # output
                     output_path: str | Path | None = None,
                     *,
                     # computation
                     in_parallel: bool,
                     # storage
                     compression_level: int = 3,
                     num_walk_matrices_in_compressed_blocks: int | None = None
                     ) -> str:
        """Generate and persist random walks for **all** nodes.

        For each node, produces ``walks_per_node`` paths split between RW and
        SAW according to ``saw_frac``. Optionally enables time-aware stepping.

        Output layout:
            Walk arrays are written as **3-D** with shape ``(T, M, L+1)``, where:
              - ``T`` is the number of time stamps,
              - ``M`` is the total number of walks produced per layer (sum over requested nodes),
              - ``L+1`` is the path length including the start node.
            If ``time_aware=False``, each layer t contains walks constrained to time t.
            If ``time_aware=True``, each layer t contains walks that **start** at time t but may evolve in time.

        Results are chunk-written to a new compressed archive.

        Args:
            walk_length: Number of steps per walk (path length is ``walk_length+1``).
            walks_per_node: Total walks per node (integer).
            saw_frac: Fraction in ``[0,1]`` of per-node walks that are SAWs
                (remainder are RWs).
            include_attractive: If ``True``, generate walks on the attractive channel.
            include_repulsive: If ``True``, generate walks on the repulsive channel.
            time_aware: If ``True``, enable time evolution via :meth:`_step_time`.
            stickiness: Required when ``time_aware=True``; probability of staying
                at the current time step.
            on_no_options: Required when ``time_aware=True``; when no alternative
                time stamps are available, either ``"raise"`` or ``"loop"``.
            output_path: Destination; its suffix is replaced by ``.zip``.
                Defaults to ``WALKS_<timestamp>.zip`` in the current working
                directory.
            in_parallel: Use :class:`ProcessPoolExecutor` to parallelize over
                node batches (requires main-process guard).
            compression_level: Compression level for the output archive.
            num_walk_matrices_in_compressed_blocks: Max number of walk matrices
                per compressed chunk when writing. Defaults to number of batches.

        Returns:
            String path to the written archive.

        Raises:
            ValueError: If ``saw_frac`` is outside ``[0,1]``, if
                ``time_aware=True`` without ``stickiness``/``on_no_options``,
                or if a requested channel is absent from the RIN archive.
            RuntimeError: If run in parallel without a main-process guard, or when
                no valid transitions are available during stepping.
        """
        _logger.info(
            "sample_walks: L=%d, per_node=%d, saw_frac=%.3f, time_aware=%s, out=%s, "
            "parallel=%s, arrays_per_chunk=%s",
            walk_length, walks_per_node, saw_frac, time_aware, output_path, in_parallel,
            num_walk_matrices_in_compressed_blocks,
        )

        current_time = sawnergy_util.current_time()

        output_path = Path((output_path or (Path(os.getcwd()) /
                            f"WALKS_{current_time}"))).with_suffix(".zip")
        _logger.debug("Output archive path: %s", output_path)

        # ---- fail fast: validate everything before the archive is created ----
        if not (0.0 <= saw_frac <= 1.0):
            _logger.error("saw_frac out of range: %r", saw_frac)
            raise ValueError("saw_frac must be in [0, 1]")

        if time_aware and (stickiness is None or on_no_options is None):
            _logger.error("time_aware=True but stickiness/on_no_options missing")
            raise ValueError("time_aware=True requires both `stickiness` and `on_no_options`.")

        # Raises ValueError here rather than deep inside a worker.
        if include_attractive:
            self._matrices_of_interaction_type("attr")
        if include_repulsive:
            self._matrices_of_interaction_type("repuls")

        # Deterministic integer split
        num_SAWs = int(round(walks_per_node * float(saw_frac)))
        num_RWs = int(walks_per_node) - num_SAWs
        _logger.info("Per-node counts: SAWs=%d, RWs=%d", num_SAWs, num_RWs)

        num_workers = os.cpu_count() or 1
        batch_size_nodes = (num_workers if in_parallel else 1)
        _logger.debug("Workers=%d, batch_size_nodes=%d", num_workers, batch_size_nodes)

        if in_parallel and not sawnergy_util.is_main_process():
            _logger.error("Process-based parallelism requires main-process guard")
            raise RuntimeError(
                "Process-based parallelism requires running under `if __name__ == '__main__':`."
            )

        processor = sawnergy_util.elementwise_processor(
            in_parallel=in_parallel,
            Executor=ProcessPoolExecutor,
            max_workers=num_workers,
            capture_output=True
        )
        _logger.debug("elementwise_processor created (parallel=%s, workers=%d)", in_parallel, num_workers)

        # Pre-build node batches deterministically
        _logger.debug("Building node batches via sawnergy_util.batches_of (batch_size_nodes=%d)", batch_size_nodes)
        node_batches = list(sawnergy_util.batches_of(self.nodes, batch_size=batch_size_nodes))
        _logger.debug("Built %d node batches", len(node_batches))

        # Derive deterministic child seeds from master seed — stable per batch.
        # NOTE(review): the same work_items (hence the same child seeds) are
        # passed to every time layer and every channel/walk kind below, so each
        # batch restarts from an identical random stream each time. Kept as-is
        # because it defines the recorded "SeedSequence.spawn_per_batch_v1" scheme.
        master_ss = np.random.SeedSequence(self._seed)
        child_seeds = master_ss.spawn(len(node_batches))
        work_items = list(zip(node_batches, child_seeds))
        _logger.debug("Prepared %d work_items with child seeds", len(work_items))

        num_walk_matrices_in_compressed_blocks = (
            num_walk_matrices_in_compressed_blocks or len(node_batches)
        )
        _logger.info("arrays_per_chunk resolved to: %d", num_walk_matrices_in_compressed_blocks)

        attractive_RWs_name = "ATTRACTIVE_RWs"
        attractive_SAWs_name = "ATTRACTIVE_SAWs"
        repulsive_RWs_name = "REPULSIVE_RWs"
        repulsive_SAWs_name = "REPULSIVE_SAWs"

        # Ordered generation plan: (log label, interaction type, self_avoid,
        # walks per node, destination block name).
        plan: list[tuple[str, str, bool, int, str]] = []
        if include_attractive:
            plan.append(("ATTR RWs", "attr", False, num_RWs, attractive_RWs_name))
            plan.append(("ATTR SAWs", "attr", True, num_SAWs, attractive_SAWs_name))
        if include_repulsive:
            plan.append(("REPULS RWs", "repuls", False, num_RWs, repulsive_RWs_name))
            plan.append(("REPULS SAWs", "repuls", True, num_SAWs, repulsive_SAWs_name))

        with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level) as storage:
            for label, interaction_type, self_avoid, walks_each, block_name in plan:
                _logger.info("Generating %s ...", label)
                walks_3d = self._walks_per_time(
                    work_items, processor, walks_each,
                    walk_length=walk_length,
                    interaction_type=interaction_type,
                    self_avoid=self_avoid,
                    time_aware=time_aware,
                    stickiness=stickiness,
                    on_no_options=on_no_options,
                )
                _logger.info("%s (per-time): shape=%s", label, walks_3d.shape)
                storage.write(
                    walks_3d,
                    to_block_named=block_name,
                    arrays_per_chunk=num_walk_matrices_in_compressed_blocks
                )

            # useful metadata
            storage.add_attr("time_created", current_time)
            storage.add_attr("seed", int(self._seed))
            storage.add_attr("rng_scheme", "SeedSequence.spawn_per_batch_v1")
            storage.add_attr("num_workers", int(num_workers))
            storage.add_attr("in_parallel", bool(in_parallel))
            storage.add_attr("batch_size_nodes", int(batch_size_nodes))
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            storage.add_attr("num_RWs", num_RWs)
            storage.add_attr("num_SAWs", num_SAWs)
            storage.add_attr("node_count", self.node_count)
            storage.add_attr("walk_length", walk_length)
            storage.add_attr("walks_per_node", walks_per_node)
            storage.add_attr("time_stamp_count", self.time_stamp_count)
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            storage.add_attr("attractive_RWs_name", attractive_RWs_name if include_attractive else None)
            storage.add_attr("repulsive_RWs_name", repulsive_RWs_name if include_repulsive else None)
            storage.add_attr("attractive_SAWs_name", attractive_SAWs_name if include_attractive else None)
            storage.add_attr("repulsive_SAWs_name", repulsive_SAWs_name if include_repulsive else None)
            storage.add_attr("walks_layout", "time_leading_3d")  # (T, M, L+1) for all modes

            _logger.info("Wrote metadata")

        _logger.info("sample_walks complete -> %s", str(output_path))
        return str(output_path)
788
+
789
+
790
+ __all__ = [
791
+ "Walker"
792
+ ]
793
+
794
+ if __name__ == "__main__":
795
+ pass