sawnergy 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy might be problematic. Click here for more details.
- sawnergy/__init__.py +13 -0
- sawnergy/embedding/SGNS_pml.py +135 -0
- sawnergy/embedding/SGNS_torch.py +177 -0
- sawnergy/embedding/__init__.py +34 -0
- sawnergy/embedding/embedder.py +578 -0
- sawnergy/logging_util.py +54 -0
- sawnergy/rin/__init__.py +9 -0
- sawnergy/rin/rin_builder.py +936 -0
- sawnergy/rin/rin_util.py +391 -0
- sawnergy/sawnergy_util.py +1182 -0
- sawnergy/visual/__init__.py +42 -0
- sawnergy/visual/visualizer.py +690 -0
- sawnergy/visual/visualizer_util.py +387 -0
- sawnergy/walks/__init__.py +16 -0
- sawnergy/walks/walker.py +795 -0
- sawnergy/walks/walker_util.py +384 -0
- sawnergy-1.0.0.dist-info/METADATA +290 -0
- sawnergy-1.0.0.dist-info/RECORD +22 -0
- sawnergy-1.0.0.dist-info/WHEEL +5 -0
- sawnergy-1.0.0.dist-info/licenses/LICENSE +201 -0
- sawnergy-1.0.0.dist-info/licenses/NOTICE +4 -0
- sawnergy-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
# third-party
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.colors as mcolors
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
# built-in
|
|
6
|
+
from typing import Iterable
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
# *----------------------------------------------------*
|
|
10
|
+
# GLOBALS
|
|
11
|
+
# *----------------------------------------------------*
|
|
12
|
+
|
|
13
|
+
_logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# DISCRETE
# Hex color strings; one named constant per swatch.
BLUE = "#3B82F6"    # Tailwind Blue 500
GREEN = "#10B981"   # Emerald Green
RED = "#EF4444"     # Soft Red
YELLOW = "#FACC15"  # Amber Yellow
PURPLE = "#8B5CF6"  # Vibrant Purple
PINK = "#EC4899"    # Modern Pink
TEAL = "#14B8A6"    # Teal
ORANGE = "#F97316"  # Bright Orange
CYAN = "#06B6D4"    # Cyan
INDIGO = "#6366F1"  # Indigo
GRAY = "#6B7280"    # Neutral Gray
LIME = "#84CC16"    # Lime Green
ROSE = "#F43F5E"    # Rose
SKY = "#0EA5E9"     # Sky Blue
SLATE = "#475569"   # Slate Gray

# CONTINUOUS SPECTRUM
# NOTE(review): presumably consumed as Matplotlib colormap names -- confirm at call sites.
HEAT = "autumn"
COLD = "winter"
|
|
35
|
+
|
|
36
|
+
# *----------------------------------------------------*
|
|
37
|
+
# FUNCTIONS
|
|
38
|
+
# *----------------------------------------------------*
|
|
39
|
+
|
|
40
|
+
# -=-=-=-=-=-=-=-=-=-=-=- #
|
|
41
|
+
# CONVENIENCE
|
|
42
|
+
# -=-=-=-=-=-=-=-=-=-=-=- #
|
|
43
|
+
|
|
44
|
+
def ensure_backend(show: bool) -> None:
    """Fall back to the non-interactive ``Agg`` backend on headless Linux.

    A session is treated as headless when the platform name starts with
    ``"linux"`` and neither ``DISPLAY`` nor ``WAYLAND_DISPLAY`` is set to a
    non-empty value. Nothing happens unless ``show`` is true *and* the session
    is headless.

    Args:
        show: Whether the caller intends to open an interactive window.

    Side Effects:
        When the fallback triggers, calls ``matplotlib.use("Agg", force=True)``,
        emits a warning through ``warnings.warn`` (attributed to the caller of
        this function) and logs a WARNING record on this module's logger.

    Notes:
        Intended to be called before ``matplotlib.pyplot`` is imported, so the
        backend is settled before any figure machinery is initialised.
    """
    import logging
    import os
    import sys
    import warnings

    headless = (
        sys.platform.startswith("linux")
        and not os.environ.get("DISPLAY")
        and not os.environ.get("WAYLAND_DISPLAY")
    )
    if not (show and headless):
        return

    # Imported only when a switch is actually required, so the common path
    # (display available, or nothing to show) does not pay for the import.
    import matplotlib

    matplotlib.use("Agg", force=True)
    warnings.warn(
        "No GUI/display detected. Falling back to non-interactive 'Agg' backend. "
        "Figures will be saved to files instead of shown.",
        stacklevel=2,  # point the warning at the caller, not at this helper
    )
    logging.getLogger(__name__).warning(
        "Headless environment detected; switched Matplotlib backend to 'Agg'."
    )
|
|
64
|
+
|
|
65
|
+
def warm_start_matplotlib() -> None:
    """Prime Matplotlib caches and the 3D pipeline.

    Performs a lightweight warm-up intended to avoid the first-draw stall
    often seen in Matplotlib, especially with 3D axes and colorbars: it
    touches the font manager and then renders a tiny throwaway 3D figure
    with a colorbar.

    Side Effects:
        Scans system fonts, instantiates a ``FontManager``, creates a 1x1 inch
        figure with a 3D axis and a colorbar, requests a draw, pauses briefly,
        and closes the figure again. The figure is closed even when one of the
        intermediate steps fails, so no warm-up figure is left registered
        with pyplot.

    Raises:
        Nothing. Every exception is intentionally swallowed and logged at
        DEBUG level; the warm-up is strictly best-effort.
    """
    _logger.debug("warm_start_matplotlib: starting.")

    # Step 1: font discovery / font manager construction.
    try:
        from matplotlib import font_manager
        font_manager.findSystemFonts()
        font_manager.FontManager()
        _logger.debug("warm_start_matplotlib: font manager primed.")
    except Exception as e:
        _logger.debug("warm_start_matplotlib: font warmup failed: %s", e)

    # Step 2: tiny 3D figure + colormap + initial render.
    fig = None
    try:
        fig = plt.figure(figsize=(1, 1))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot([0, 1], [0, 1], [0, 1])
        fig.colorbar(plt.cm.ScalarMappable(cmap="viridis"), ax=ax, fraction=0.2, pad=0.04)
        fig.canvas.draw_idle()
        plt.pause(0.01)
        _logger.debug("warm_start_matplotlib: 3D pipeline primed.")
    except Exception as e:
        _logger.debug("warm_start_matplotlib: 3D warmup failed: %s", e)
    finally:
        # Always release the throwaway figure; previously a failure in any of
        # the steps above skipped plt.close() and leaked the figure.
        if fig is not None:
            try:
                plt.close(fig)
            except Exception as e:
                _logger.debug("warm_start_matplotlib: 3D warmup failed: %s", e)
|
|
100
|
+
|
|
101
|
+
def map_groups_to_colors(N: int,
                         groups: tuple[tuple[Iterable[int], str], ...] | None,
                         default_color: str,
                         one_based: bool = True) -> list[tuple[float, float, float, float]]:
    """Map index groups to RGBA colors.

    Builds a list of ``N`` RGBA tuples initialised to ``default_color`` and
    then overwrites the entries named by each group.

    Args:
        N: Total number of items (length of the returned list).
        groups: Optional sequence of ``(indices, color)`` pairs, where
            ``indices`` is any iterable of int indices and ``color`` is
            anything ``matplotlib.colors.to_rgba`` accepts (e.g.
            ``"#EF4444"``). If ``None``, every entry keeps ``default_color``.
        default_color: Color used for all indices not covered by ``groups``.
        one_based: If ``True``, indices are interpreted as 1-based and are
            shifted down by one before use. If ``False``, they are used as-is.

    Returns:
        A list of length ``N`` holding one RGBA tuple per item.

    Raises:
        IndexError: If an index falls outside ``[0, N-1]`` after the optional
            1-based -> 0-based conversion. Negative indices are rejected
            rather than wrapped.

    Notes:
        No deduplication is performed across groups; when an index appears in
        several groups, the last group listed wins.
    """
    _logger.debug("map_groups_to_colors: N=%s, groups=%s, default_color=%s, one_based=%s",
                  N, None if groups is None else len(groups), default_color, one_based)
    base = mcolors.to_rgba(default_color)
    # RGBA tuples are immutable, so sharing one object across slots is safe.
    colors = [base] * N
    if groups is not None:
        offset = 1 if one_based else 0
        for indices, hex_color in groups:
            col = mcolors.to_rgba(hex_color)
            for idx in indices:
                i = idx - offset
                if not (0 <= i < N):
                    # Report the index as the caller wrote it (pre-conversion).
                    _logger.error("map_groups_to_colors: index %s out of range for N=%s", idx, N)
                    raise IndexError(f"Index {idx} out of range for N={N}")
                colors[i] = col
    _logger.debug("map_groups_to_colors: completed.")
    return colors
|
|
149
|
+
|
|
150
|
+
# -=-=-=-=-=-=-=-=-=-=-=- #
|
|
151
|
+
# SCENE CONSTRUCTION
|
|
152
|
+
# -=-=-=-=-=-=-=-=-=-=-=- #
|
|
153
|
+
|
|
154
|
+
def absolute_quantile(N: int, weights: np.ndarray, frac: float) -> float:
    """Return the weight threshold that isolates the top ``frac`` of edges.

    Only the strict upper triangle (``k=1``) of ``weights`` is considered, so
    each undirected pair is counted once and the diagonal is ignored. The
    result is the ``1.0 - frac`` quantile of those values; for instance
    ``frac=0.25`` gives the 75th percentile.

    Args:
        N: Number of nodes; the upper-triangle indices are generated for an
            ``N x N`` matrix.
        weights: 2D array of edge weights, indexed with those indices.
        frac: Share of the heaviest edges to keep (e.g. ``0.01`` for the top
            1%).

    Returns:
        float: The threshold such that edges ``>= threshold`` make up roughly
        the top ``frac`` of upper-triangle weights, or ``0.0`` when there are
        no upper-triangle entries at all.

    Raises:
        Nothing is validated explicitly here; errors raised by NumPy indexing
        or by ``np.quantile`` propagate to the caller.
    """
    _logger.debug(
        "absolute_quantile: N=%s, weights.shape=%s, frac=%s",
        N, getattr(weights, "shape", None), frac,
    )
    upper_tri = weights[np.triu_indices(N, k=1)]
    edge_count = upper_tri.size

    if edge_count == 0:
        _logger.debug("absolute_quantile: no upper-tri edges; returning 0.0")
        return 0.0

    threshold = float(np.quantile(upper_tri, 1.0 - frac))
    _logger.debug("absolute_quantile: computed threshold=%s over %d edges", threshold, edge_count)
    return threshold
|
|
186
|
+
|
|
187
|
+
def row_wise_norm(weights: np.ndarray) -> np.ndarray:
    """Scale every row of a weight matrix by that row's sum.

    Rows whose sum is exactly zero are left as zeros instead of producing
    NaN/inf from a division by zero.

    Args:
        weights: 2D array, typically an ``(N, N)`` adjacency/weight matrix.

    Returns:
        np.ndarray: Float array shaped like ``weights``; each row with a
        non-zero sum adds up to ~1, every zero-sum row is all zeros.
    """
    _logger.debug("row_wise_norm: weights.shape=%s", getattr(weights, "shape", None))

    # keepdims=True keeps the sums as a column so they broadcast across rows.
    row_totals = np.sum(weights, axis=1, keepdims=True)
    normalized = np.zeros_like(weights, dtype=float)
    has_mass = row_totals != 0
    # Division is performed only where the row sum is non-zero; everything
    # else keeps the zero it was initialised with.
    np.divide(weights, row_totals, out=normalized, where=has_mass)

    try:
        smallest = float(row_totals.min())
        largest = float(row_totals.max())
        _logger.debug("row_wise_norm: row_sums[min=%.6g, max=%.6g]", smallest, largest)
    except Exception:
        # e.g. min()/max() of an empty array; statistics are informational only.
        _logger.debug("row_wise_norm: row_sums stats unavailable.")

    return normalized
|
|
209
|
+
|
|
210
|
+
def absolute_norm(weights: np.ndarray) -> np.ndarray:
    """Scale an array so that all of its entries sum to one.

    Every entry is divided by the grand total. When that total is zero or
    non-finite the function returns zeros rather than NaN/inf values.

    Args:
        weights: Array of any shape; often an ``(N, N)`` matrix of edge
            weights.

    Returns:
        np.ndarray: Array shaped like ``weights`` whose values sum to ~1, or a
        float array of zeros if the total is zero/non-finite.
    """
    _logger.debug("absolute_norm: weights.shape=%s", getattr(weights, "shape", None))

    grand_total = np.sum(weights)
    can_normalize = np.isfinite(grand_total) and grand_total != 0
    if not can_normalize:
        _logger.debug("absolute_norm: total_sum is zero/non-finite; returning zeros")
        return np.zeros_like(weights, dtype=float)

    scaled = weights / grand_total
    try:
        _logger.debug("absolute_norm: total_sum=%.6g", float(grand_total))
    except Exception:
        # Logging the total is informational only; never fail because of it.
        _logger.debug("absolute_norm: total_sum unavailable.")
    return scaled
|
|
234
|
+
|
|
235
|
+
def build_line_segments(
    N: int,
    include: np.ndarray,
    coords: np.ndarray,
    weights: np.ndarray,
    top_frac_weights_displayed: float,
    *,
    global_weights_frac: bool = True,
    global_opacity: bool = True,
    global_color_saturation: bool = True
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Assemble edge line segments plus per-edge color and opacity weights.

    Edge selection happens in two stages:
      1) only upper-triangle pairs whose two endpoints are both listed in
         ``include`` are considered candidates;
      2) candidates whose weight is ``>=`` a quantile threshold are kept,
         which retains roughly the top ``top_frac_weights_displayed`` share.

    Args:
        N: Total number of nodes (size of the weight matrix).
        include: 1D array-like of 0-based node indices to consider.
        coords: Per-node coordinates with ``N`` rows (documented as
            ``(N, 3)``); rows are looked up by node index.
        weights: 2D array of shape ``(N, N)`` with edge weights.
        top_frac_weights_displayed: Share of the heaviest edges to keep
            (approximately).
        global_weights_frac: If ``True`` the threshold comes from
            ``absolute_quantile`` over **all** upper-triangle weights;
            otherwise from the candidate edges only.
        global_opacity: If ``True`` opacity weights are read from
            ``row_wise_norm(weights)``; otherwise from the row-normalised
            matrix that contains only the kept edges (others zeroed).
        global_color_saturation: If ``True`` color weights are read from
            ``absolute_norm(weights)``; otherwise from the sum-normalised
            matrix that contains only the kept edges (others zeroed).

    Returns:
        tuple:
            - ``line_segments`` (np.ndarray): one ``[from_xyz, to_xyz]`` pair
              per kept edge, stacked along axis 1 (``(E, 2, 3)`` for 3D
              coordinates).
            - ``color_weights`` (np.ndarray): shape ``(E,)`` scalars intended
              for colormap mapping.
            - ``opacity_weights`` (np.ndarray): shape ``(E,)`` scalars
              intended for alpha/opacity.
        All three are empty (``(0, 2, 3)``, ``(0,)``, ``(0,)``) when no edge
        survives either stage.

    Raises:
        ValueError: If ``coords`` does not have ``N`` rows. This check runs
            only after edges were selected, so the empty-result early returns
            do not reach it.
    """
    include_len = None if include is None else np.asarray(include).size
    _logger.debug(
        "build_line_segments: N=%s, include.len=%s, coords.shape=%s, weights.shape=%s, top_frac=%s, "
        "global_weights_frac=%s, global_opacity=%s, global_color_saturation=%s",
        N,
        include_len,
        getattr(coords, "shape", None),
        getattr(weights, "shape", None),
        top_frac_weights_displayed,
        global_weights_frac, global_opacity, global_color_saturation
    )

    def _empty_result() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # Shared "nothing to draw" payload for both early exits.
        return (np.empty((0, 2, 3), dtype=float),
                np.empty((0,), dtype=float),
                np.empty((0,), dtype=float))

    # Stage 0: every unordered node pair (i < j).
    src, dst = np.triu_indices(N, k=1)

    # Stage 1: keep pairs whose BOTH endpoints are in 'include'.
    is_included = np.zeros(N, dtype=bool)
    is_included[np.asarray(include, dtype=int)] = True
    both_included = is_included[src] & is_included[dst]
    src = src[both_included]
    dst = dst[both_included]

    if src.size == 0:
        _logger.debug("build_line_segments: no candidate edges after endpoint filter; returning empties.")
        return _empty_result()

    candidate_weights = weights[src, dst]

    # Stage 2: quantile threshold, taken globally or over the candidates only.
    if global_weights_frac:
        thresh = absolute_quantile(N, weights, top_frac_weights_displayed)
    else:
        thresh = float(np.quantile(candidate_weights, 1.0 - top_frac_weights_displayed))

    survivors = candidate_weights >= thresh
    src = src[survivors]
    dst = dst[survivors]

    if src.size == 0:
        _logger.debug("build_line_segments: no edges kept after threshold; returning empties.")
        return _empty_result()

    # Matrix holding ONLY the kept edges; mirrored so row sums stay symmetric.
    kept_only = np.zeros_like(weights)
    kept_values = weights[src, dst]
    kept_only[src, dst] = kept_values
    kept_only[dst, src] = kept_values

    # Opacity: row-normalised weights from the full or the kept-only matrix.
    opacity_source = weights if global_opacity else kept_only
    opacity_weights = row_wise_norm(opacity_source)[src, dst]

    # Color: sum-normalised weights from the full or the kept-only matrix.
    color_source = weights if global_color_saturation else kept_only
    color_weights = absolute_norm(color_source)[src, dst]

    # Coordinates must provide one row per node so node indices can address them.
    coords = np.asarray(coords)
    if coords.shape[0] != N:
        raise ValueError(
            f"`coords` must be shape (N, 3) with N={N}. "
            "If you spread only displayed nodes, create a copy of the full frame coords and "
            "overwrite those displayed rows before calling this function."
        )

    # One [start, end] coordinate pair per kept edge.
    line_segments = np.stack((coords[src], coords[dst]), axis=1)

    _logger.debug(
        "build_line_segments: segments.shape=%s, color_w.shape=%s, opacity_w.shape=%s, thresh=%.6g, kept=%d",
        getattr(line_segments, "shape", None),
        getattr(color_weights, "shape", None),
        getattr(opacity_weights, "shape", None),
        thresh, src.size,
    )

    return line_segments, color_weights, opacity_weights
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
# Names exported by ``from ... import *``: the palette and colormap constants.
__all__ = [
    # discrete colors
    "BLUE", "GREEN", "RED", "YELLOW", "PURPLE",
    "PINK", "TEAL", "ORANGE", "CYAN", "INDIGO",
    "GRAY", "LIME", "ROSE", "SKY", "SLATE",
    # continuous spectrum
    "HEAT", "COLD",
]

# Running this module as a script is intentionally a no-op.
if __name__ == "__main__":
    pass
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# walker
|
|
2
|
+
from .walker import Walker
|
|
3
|
+
from .walker_util import (
|
|
4
|
+
SharedNDArray,
|
|
5
|
+
l1_norm,
|
|
6
|
+
apply_on_axis0,
|
|
7
|
+
cosine_similarity
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
# Public names re-exported by this package's ``__init__``.
__all__ = [
    "Walker",
    # helpers re-exported from walker_util
    "SharedNDArray", "l1_norm", "apply_on_axis0", "cosine_similarity",
]
|