PyGhostID 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- PyGhostID/__init__.py +19 -0
- PyGhostID/_utils.py +416 -0
- PyGhostID/core.py +1330 -0
- pyghostid-1.0.0.dist-info/METADATA +19 -0
- pyghostid-1.0.0.dist-info/RECORD +8 -0
- pyghostid-1.0.0.dist-info/WHEEL +5 -0
- pyghostid-1.0.0.dist-info/licenses/LICENSE +674 -0
- pyghostid-1.0.0.dist-info/top_level.txt +1 -0
PyGhostID/core.py
ADDED
|
@@ -0,0 +1,1330 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Core functions of PyGhostID
|
|
4
|
+
|
|
5
|
+
@author: Daniel Koch, 2026
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Import packages
|
|
9
|
+
import numpy as np
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from scipy.signal import find_peaks
|
|
13
|
+
from scipy.spatial import cKDTree
|
|
14
|
+
from scipy.optimize import (
|
|
15
|
+
minimize,
|
|
16
|
+
differential_evolution,
|
|
17
|
+
dual_annealing,
|
|
18
|
+
basinhopping,
|
|
19
|
+
)
|
|
20
|
+
from scipy.stats import qmc
|
|
21
|
+
|
|
22
|
+
from scipy.integrate import solve_ivp
|
|
23
|
+
import matplotlib.pyplot as plt
|
|
24
|
+
import networkx as nx
|
|
25
|
+
from networkx.drawing.nx_agraph import graphviz_layout
|
|
26
|
+
from ._utils import *
|
|
27
|
+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
|
28
|
+
import os
|
|
29
|
+
import sys
|
|
30
|
+
from tqdm import tqdm
|
|
31
|
+
|
|
32
|
+
# Functions
|
|
33
|
+
|
|
34
|
+
def ghostID(model, params, dt, trajectory, epsilon_gid=0.05, **kwargs):

    """
    Identify ghost states along a simulated trajectory by detecting Q-minima and
    evaluating the eigenvalue spectrum of the Jacobian along trajectory segments
    in their vicinity.

    Parameters
    ----------
    model : callable
        Python function describing the system dynamics.
    params : list or array-like
        Parameters passed as arguments to model.
    dt : float
        Step-size of the numerical integration used to simulate trajectory.
        Consecutive trajectory points are assumed to be dt apart, starting at t=0.
    trajectory : array-like
        A trajectory simulated by the system to be analysed for ghost states,
        indexed as trajectory[time_index, coordinate].
    epsilon_gid : float, optional
        Radius of the epsilon-sphere around Q-minima. Determines which trajectory
        segments are used for eigenvalue evaluation: only points within this distance
        of a Q-minimum are included. Values in the range 0.01--0.1 are reasonable
        for many models (default is 0.05). The value should be large enough that
        segments contain sufficient points for reliable eigenvalue estimation, but
        small enough that eigenvalues remain representative of the local phase-space
        topology around the candidate ghost. A practical tuning strategy is to start
        with small values and increase until eigenvalue control plots look reasonably
        smooth.
    **kwargs
        Optional keyword arguments:

        delta_gid : float
            Phase-space distance below which two ghosts identified by GhostID are
            considered the same and assigned the same identifier. Default is 0.1.
        peak_kwargs : dict
            Additional keyword arguments passed to scipy.signal.find_peaks to
            improve detection of Q-minima from peaks in the pQ time-series.
            If "width" is not given, 5 * dt is used. The caller's dict is not
            modified.
        evLimit : float
            Default is 0 (disabled). If set to a value greater than 0, enables the
            indirect method of ghost identification: a trajectory segment is
            considered to pass near a ghost if (1) the absolute value of the median
            of eigenvalues along the segment is below evLimit, (2) the linear fit of
            eigenvalues has R² >= 0.99, and (3) the slope of the fit lies within the
            range given by slopeLimits.
        slopeLimits : array-like of length 2
            Lower and upper bounds (exclusive) for the eigenvalue slope used in the
            indirect identification method. Ignored if evLimit = 0. Default is [0, inf].
        batchModel : callable
            Vectorized (batch) version of model able to handle batch inputs. Can be
            provided to improve performance when ghostID is called repeatedly with the
            same model, avoiding repeated calls to make_batch_model. If not provided,
            ghostID constructs a batch model internally via make_batch_model.

    Eigenvalue cleaning
    -------------------
    Because eigenvalue indexing along a trajectory segment is not guaranteed to
    be consistent across consecutive time steps (i.e. lambda_1 at time t may be
    lambda_2 at time t+1), two complementary correction methods are available:

        ev_outlier_removal : bool
            If True, removes eigenvalue outliers detected within a sliding window.
            An eigenvalue is considered an outlier if it falls above q3 + k*(q3-q1)
            or below q1 - k*(q3-q1), where q1 and q3 are the 25th and 75th
            percentiles and k is set by ev_outlier_removal_k. Default is False.
        ev_outlier_removal_k : float
            Controls the width of the non-outlier range (see ev_outlier_removal).
            Default is 1.5.
        ev_outlier_removal_ws : float
            Size of the sliding window used for outlier detection. Default is 7.
        eigval_NN_sorting : bool
            If True, sorts eigenvalues across time for each index i using a
            nearest-neighbour prediction. Useful when eigenvalue time-series appear
            scattered or discontinuous. Default is False.

    Control outputs
    ---------------
        display_warnings : bool
            Show or suppress warning messages from GhostID.
        ctrlOutputs : dict
            Controls diagnostic plots of the algorithm's two core quantities
            (pQ-values and eigenvalues). Recognised keys:

                ctrl_qplot (bool)    : plot pQ-values and detected Q-minima along
                                       the trajectory.
                qplot_xscale (str)   : x-axis scale for Q-plot, 'linear' (default)
                                       or 'log'.
                qplot_yscale (str)   : y-axis scale for Q-plot, 'linear' (default)
                                       or 'log'.
                ctrl_evplot (bool)   : plot eigenvalues along each trajectory segment
                                       around identified Q-minima, including the
                                       evaluation criteria listed in the plot heading.
                evplot_xscale (str)  : x-axis scale for eigenvalue plot, 'linear'
                                       (default) or 'log'. Applied to every subplot.
                evplot_yscale (str)  : y-axis scale for eigenvalue plot, 'linear'
                                       (default) or 'log'. Applied to every subplot.
        return_ctrl_figs : bool
            If True, returns control plot figures for manual customisation instead
            of displaying them inline. Default is False.

    Returns
    -------
    ghostSeq : list of dict
        List of identified ghost states in time order. Each entry has the keys
        "id", "time", "duration", "position", "dimension", "q-value",
        "crossing_eigenvalues", "qualifying_slopes" and "eigenvalues_qmin"
        (the complex Jacobian eigenvalues at the Q-minimum itself).
        If a Q-minimum has fewer than 5 trajectory points within epsilon_gid,
        an error is printed and an empty list is returned.
    control_figures : list of (figure, axes), optional
        Control plot figures, returned only if return_ctrl_figs is True.

    """

    # Parse and validate kwargs
    config = parse_kwargs(**kwargs)

    # Extract parameters from config
    display_warnings = config['display_warnings']
    delta_gid = config['delta_gid']
    # Copy before filling in the default: writing into the parsed dict could
    # mutate a dict supplied by the caller and leak a stale dt-based width into
    # later calls with a different dt.
    peak_kwargs = dict(config['peak_kwargs'])
    if "width" not in peak_kwargs:
        # NOTE(review): find_peaks interprets `width` in samples, whereas this
        # default is expressed as 5 * dt (time units) - confirm intent.
        peak_kwargs["width"] = 5 * dt  # default if not supplied
    batchModel = config['batchModel']
    eigval_NN_sorting = config['eigval_NN_sorting']
    ev_outlier_removal = config['ev_outlier_removal']
    ev_outlier_removal_ws = config['ev_outlier_removal_ws']
    ev_outlier_removal_k = config['ev_outlier_removal_k']
    evLimit = config['evLimit']
    slopeLimits = config['slopeLimits']

    # Plotting control settings
    return_ctrl_figs = config['return_ctrl_figs']
    ctrl_qplot = config['ctrl_qplot']
    qplot_xscale = config['qplot_xscale']
    qplot_yscale = config['qplot_yscale']
    ctrl_evplot = config['ctrl_evplot']
    evplot_xscale = config['evplot_xscale']
    evplot_yscale = config['evplot_yscale']

    # None -> show figures inline; list -> collect them for the caller.
    ctrl_figures = [] if return_ctrl_figs else None

    # Vector field evaluated at every trajectory point
    if batchModel is not None:
        Xs = batchModel(trajectory)
    else:
        Xs = make_batch_model(model, params)(trajectory)

    n = trajectory.shape[1]  # dimension of trajectory

    ############# STEP 1 - identify non-oscillatory saddle-node ghosts #############

    # Q-minima along the trajectory appear as peaks of pQ = -log(Q)
    Q_ts = 0.5 * np.sqrt(np.sum(Xs**2, axis=1))
    pQ = -np.log(Q_ts)
    idx_minima, pk_props = find_peaks(pQ, **peak_kwargs)  # positions of Q-minima

    if ctrl_qplot:
        fig, ax = _ghostid_qplot(pQ, dt, idx_minima, pk_props, qplot_xscale, qplot_yscale)
        _ghostid_show_or_store(fig, ax, ctrl_figures)

    # Precompile JAX Jacobian function
    J_fun = make_jacfun(model, params)

    # Build KD-tree once for trajectory
    kdtree = cKDTree(trajectory)

    # Initialised unconditionally so the merge step below also works when no
    # Q-minima were found.
    ghostSeq = []          # sequence of visited ghosts
    ghostTimes = []        # visit times - used to order ghosts and oscillatory transients
    ghostCoordinates = []  # unique phase-space positions of all ghosts visited

    for k_min, i in enumerate(idx_minima):

        t_ghost = i * dt  # time at which a potential ghost was found
        dur_ghost = pk_props["widths"][k_min] * dt  # trapping time
        qmin_xyz = trajectory[i]  # position of Q_minimum in phase space

        # KD-tree neighborhood query
        idcs_Ueps_qmin = np.sort(np.asarray(kdtree.query_ball_point(qmin_xyz, epsilon_gid), dtype=int))

        if len(idcs_Ueps_qmin) < 5:
            print("ghostID error: insuffienct number of points in Ueps!")
            return ([], ctrl_figures) if return_ctrl_figs else []

        idcs_segment = trjSegment(idcs_Ueps_qmin, i)

        # Check if trajectory leaves the epsilon environment after the Q-minimum
        dists = np.linalg.norm(trajectory[i:] - qmin_xyz, axis=1)
        leaves_eps_qmin_i = bool(np.any(dists > epsilon_gid))

        if not leaves_eps_qmin_i:
            if display_warnings:
                print("GhostID: Trajectory does not leave U_eps - stopping ghostID.")
            break

        # Batch Jacobian + eigenvalue evaluation for segment
        pts_segment = jnp.asarray(trajectory[idcs_segment])  # JAX array
        J_batch = jax.vmap(J_fun)(pts_segment)  # batch Jacobians
        eigVals = jax.vmap(jnp.linalg.eigvals)(J_batch)  # eigenvalues, one row per segment point
        eigVals_real_ = np.real(np.asarray(eigVals))  # back to numpy for analysis

        eigVals_real = sort_NN(eigVals_real_.T).T if eigval_NN_sorting else eigVals_real_

        # Determine eigenvalue crossings along the segment
        ev_signChanges = [sign_change(eigVals_real[:, ii], ev_outlier_removal, ev_outlier_removal_ws,
                                      ev_outlier_removal_k, display_warnings=display_warnings)
                          for ii in range(n)]
        crossings = sum(ev_signChanges)

        # Determine eigenvalue slopes (indirect method; inactive when evLimit == 0)
        qualifyingSlopes = []
        r2s = []
        for ii in range(n):
            # Only consider eigenvalues with small median real part along segment
            if np.abs(np.median(eigVals_real[:, ii])) < evLimit:
                slope, r2 = slope_and_r2(eigVals_real[:, ii], dt, ev_outlier_removal,
                                         ev_outlier_removal_ws, ev_outlier_removal_k)
                r2s.append(r2)
                if np.all((slope > slopeLimits[0]) & (slope < slopeLimits[1]) & (r2 >= 0.99)):
                    qualifyingSlopes.append(ii)

        # Check for ghost
        ghostCheck = bool(crossings > 0 or len(qualifyingSlopes) > 0)

        if ctrl_evplot:
            qsl = len(qualifyingSlopes) if evLimit > 0 else 'N/A'
            Q_mami = np.max(Q_ts[idcs_segment]) / Q_ts[i]
            title = (
                f"Eig.vals near Qmin at t={t_ghost:.2f}, ghost: {str(ghostCheck)[0]}, "
                f"leaves Uɛ: {str(leaves_eps_qmin_i)[0]}, Qmami: {Q_mami:.1e} \n"
                "sign changes λi: " + "".join(np.where(ev_signChanges, 'T ', 'F '))
                + f", qualifying slopes: {qsl}, R²: {[f'{ri:.3f}' for ri in r2s]}"
            )
            fig, axes = _ghostid_evplot(eigVals_real, dt, title, evplot_xscale, evplot_yscale)
            _ghostid_show_or_store(fig, axes, ctrl_figures)

        if not ghostCheck:
            continue

        # Ghost found: characterise its dimension and check if it was found before
        ghostTimes.append(t_ghost)
        gdim = max([crossings, len(qualifyingSlopes)])

        # First previously found ghost within delta_gid (None if there is none)
        match = None
        if ghostCoordinates:
            distances = np.asarray([np.linalg.norm(g - trajectory[i]) for g in ghostCoordinates])
            hits = np.flatnonzero(distances < delta_gid)
            if hits.size > 0:
                match = int(hits[0])

        if match is None:  # new ghost
            ghostCoordinates.append(trajectory[i])
            gidx = len(ghostCoordinates)
        else:  # ghost already found previously
            gidx = match + 1

        # eigVals is indexed by position within the segment, not by trajectory
        # index, so the spectrum at the Q-minimum is evaluated directly there.
        eigvals_qmin = np.asarray(jnp.linalg.eigvals(J_fun(jnp.asarray(qmin_xyz))))

        ghostSeq.append({
            "id": "G" + str(gidx),
            "time": t_ghost,
            "duration": dur_ghost,
            "position": trajectory[i],
            "dimension": gdim,
            "q-value": Q_ts[i],
            "crossing_eigenvalues": np.where(np.array(ev_signChanges) == True)[0],
            "qualifying_slopes": qualifyingSlopes,
            "eigenvalues_qmin": eigvals_qmin,
        })

    ############# STEP 2 - identify oscillatory transients #############
    oscSeq = []
    oscTimes = []

    # Merge transient state lists into one time-ordered sequence (stable sort,
    # so ghosts precede oscillatory transients recorded at the same time).
    timed_states = list(zip(ghostTimes, ghostSeq)) + list(zip(oscTimes, oscSeq))
    timed_states.sort(key=lambda item: item[0])
    fullTransientSeq = [state for _, state in timed_states]

    return (fullTransientSeq, ctrl_figures) if return_ctrl_figs else fullTransientSeq


def _ghostid_show_or_store(fig, axes, ctrl_figures):
    """Show a control figure inline, or append (fig, axes) to ctrl_figures and close it."""
    if ctrl_figures is None:
        plt.show()
    else:
        ctrl_figures.append((fig, axes))
        plt.close(fig)


def _ghostid_qplot(pQ, dt, idx_minima, pk_props, xscale, yscale):
    """Plot the pQ time-series, marking Q-minima and their prominences when present; return (fig, ax)."""
    t_axis = np.arange(len(pQ)) * dt
    fig, ax = plt.subplots()
    fig.set_size_inches(17/(2*2.54), 17/(3*2.54))
    ax.plot(t_axis, pQ, '-k', lw=0.8, label="pQ(t)")
    if len(idx_minima) > 0:
        ax.plot(idx_minima * dt, pQ[idx_minima], 'xr', label="Q-minima")
        # Plot prominence lines
        if "prominences" in pk_props:
            for idx, prom in zip(idx_minima, pk_props["prominences"]):
                x = idx * dt
                peak_val = pQ[idx]
                base_val = peak_val - prom
                # vertical line showing prominence, labelled next to the line
                ax.vlines(x, base_val, peak_val, color="gray", linestyle="--")
                ax.text(x, base_val - 0.05 * prom, f"{prom:.2f}",
                        ha="center", va="top", fontsize=7.5, color="gray")
    ax.set_ylabel("pQ(t)")
    ax.set_xlabel("t")
    ax.set_xscale(xscale)
    ax.set_yscale(yscale)
    ax.legend(fontsize=9)
    ax.set_title("Detected Q-minima with prominences", fontsize=12)
    plt.tight_layout()
    return fig, ax


def _ghostid_evplot(eigVals_real, dt, title, xscale, yscale):
    """Plot each real-eigenvalue time-series of a segment in its own subplot; return (fig, axes)."""
    n_eig = eigVals_real.shape[1]
    fig, axes = plt.subplots(n_eig, 1, figsize=(17/(2*2.54), 17/(4*2.54)*n_eig), sharex=True)
    if n_eig == 1:
        axes = [axes]  # ensure axes is iterable
    t_seg = np.arange(len(eigVals_real)) * dt
    for j, ax in enumerate(axes):
        ax.plot(t_seg, eigVals_real[:, j], '-ok', markersize=1.5, lw=0.75)
        ax.set_ylabel(f'λ{j+1}')
        # Apply the requested scales to every subplot, not only the last one.
        ax.set_xscale(xscale)
        ax.set_yscale(yscale)
    axes[-1].set_xlabel('Time along segment')
    plt.suptitle(title, fontsize=9)
    plt.tight_layout()
    return fig, axes
|
|
427
|
+
|
|
428
|
+
def ghostID_phaseSpaceSample(model, model_params, t_start, t_end, dt, state_ranges,
                             n_samples=50, method='RK45', rtol=1.e-3, atol=1.e-6, n_workers=None, **kwargs):
    """
    Identify ghost states across a region of phase space by simulating multiple
    trajectories from initial conditions drawn by Latin-hypercube sampling and
    evaluating each trajectory with ghostID.

    Implements an adaptive parallel-processing routine: uses threads when run
    inside Jupyter or Spyder (avoiding pickling issues) and processes when run
    as a standalone script for full CPU utilisation. In process mode, model
    (and batchModel, if given) must be picklable, e.g. defined at module level,
    and the calling script should guard its entry point with
    ``if __name__ == "__main__":``.

    Parameters
    ----------
    model : callable
        Python function describing the system dynamics.
    model_params : list or array-like
        Parameters passed as arguments to model.
    t_start : float
        Start time of each simulated trajectory.
    t_end : float
        End time of each simulated trajectory.
    dt : float
        Step-size for numerical integration. Trajectories are sampled on a
        grid from t_start to t_end with (approximately) this spacing.
    state_ranges : list of tuple
        List of n tuples (one per dimension of the system), each specifying the
        (lower, upper) boundaries of the phase space region to be sampled along
        that coordinate. For example, for a 2D system:
        [(x_min, x_max), (y_min, y_max)].
    n_samples : int, optional
        Number of trajectories to simulate and analyse for ghost states.
    method : str, optional
        Integration method passed to scipy.integrate.solve_ivp. Default is 'RK45'.
    rtol : float, optional
        Relative tolerance for the numerical integrator. Default is 1e-3.
    atol : float, optional
        Absolute tolerance for the numerical integrator. Default is 1e-6.
    n_workers : int or None, optional
        Number of CPU cores used for parallel processing. If None, the number
        of workers is chosen automatically. Default is None.
    **kwargs
        All keyword arguments accepted by ghostID are also accepted here (see
        ghostID documentation) and forwarded to every ghostID call. Note that
        control plots are unlikely to render correctly when parallel processing
        is enabled; returned control figures are discarded.

        Additional kwargs specific to ghostID_phaseSpaceSample:

        delta_unify : float
            Distance in phase space above which two ghosts identified across
            different trajectories in the phase space sample are considered
            distinct and assigned different identifiers. Analogous to delta_gid
            in ghostID, but applied globally across all trajectories rather than
            within a single trajectory. Default is 0.1.
            NOTE(review): the implementation reads this value from
            config['epsilon_unify']; confirm which keyword parse_kwargs accepts.
        seed : int or None
            Random seed for the Latin-hypercube sampler. Setting a specific value
            ensures exact reproducibility of the phase space sample. Default is
            None (non-reproducible random sample).

    Returns
    -------
    results : list
        List containing a ghostSeq (see ghostID documentation) for each
        simulated trajectory in which ghosts were found, in the order of the
        sampled initial conditions (independent of worker completion order),
        with IDs unified across trajectories. Empty if no ghosts were found.
    """

    # Parse and validate kwargs
    config = parse_kwargs(**kwargs)

    display_warnings = config['display_warnings']

    # Copy before filling in the default so a caller-supplied dict is not mutated.
    peak_kwargs = dict(config['peak_kwargs'])
    if "width" not in peak_kwargs:
        peak_kwargs["width"] = 5 * dt  # default if not supplied

    epsilon_gid = config['epsilon_gid']
    epsilon_unify = config['epsilon_unify']  # rename to delta_unify
    seed = config['seed']
    return_ctrl_figs = config['return_ctrl_figs']

    # Keyword arguments forwarded unchanged to every ghostID call.
    ghost_kwargs = {
        "delta_gid": config['delta_gid'],
        "peak_kwargs": peak_kwargs,
        "batchModel": config['batchModel'],
        "return_ctrl_figs": return_ctrl_figs,
        "ctrlOutputs": kwargs.get("ctrlOutputs", {}),
        "evLimit": config['evLimit'],
        "slopeLimits": config['slopeLimits'],
        "eigval_NN_sorting": config['eigval_NN_sorting'],
        "ev_outlier_removal": config['ev_outlier_removal'],
        "ev_outlier_removal_ws": config['ev_outlier_removal_ws'],
        "ev_outlier_removal_k": config['ev_outlier_removal_k'],
        "display_warnings": display_warnings,
    }

    if n_workers is None:
        n_workers = max(1, (os.cpu_count() or 4) - 1)

    # Evaluation grid from t_start (solve_ivp rejects t_eval values outside
    # t_span) with spacing ~dt, which ghostID assumes between consecutive points.
    npts = max(1, int(round((t_end - t_start) / dt)))
    t_eval = np.linspace(t_start, t_end, npts + 1)

    ICs = phaseSpaceLHS(state_ranges, n_samples, seed)

    # ---- Choose backend automatically ----
    in_spyder_or_jupyter = (
        "SPYDER" in sys.modules or
        "spyder_kernels" in sys.modules or
        "ipykernel" in sys.modules
    )
    Executor = ThreadPoolExecutor if in_spyder_or_jupyter else ProcessPoolExecutor
    mode = "threads" if Executor is ThreadPoolExecutor else "processes"
    print(f"[ghostID_phaseSpaceSample] Running with {mode} ({n_workers} workers)")

    # ---- Parallel execution ----
    # The worker is module-level (a nested function cannot be pickled for
    # ProcessPoolExecutor); results are stored by IC index so the output order
    # does not depend on which worker finishes first.
    results = [None] * len(ICs)
    with Executor(max_workers=n_workers) as executor:
        future_to_idx = {
            executor.submit(_phase_space_sample_worker, ic, model, model_params, t_start, t_end,
                            t_eval, method, rtol, atol, dt, epsilon_gid, ghost_kwargs): k
            for k, ic in enumerate(ICs)
        }
        for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx),
                           desc="Processing ICs", unit="IC"):
            results[future_to_idx[future]] = future.result()

    ghostSeqs = [res for res in results if res is not None]

    # ---- Unify results ----
    if ghostSeqs:
        return unify_IDs(ghostSeqs, epsilon_unify, False)

    if display_warnings:
        print("ghostID_phaseSpaceSample: No ghosts found in any trajectory.")
    return []


def _phase_space_sample_worker(ic, model, model_params, t_start, t_end, t_eval, method,
                               rtol, atol, dt, epsilon_gid, ghost_kwargs):
    """Integrate one initial condition and run ghostID on it; return its ghost sequence, or None if empty."""
    sol = solve_ivp(model, (t_start, t_end), ic,
                    t_eval=t_eval, args=(model_params,),
                    method=method, rtol=rtol, atol=atol)
    result = ghostID(model, model_params, dt, sol.y.T, epsilon_gid, **ghost_kwargs)
    # With return_ctrl_figs, ghostID returns (ghostSeq, ctrl_figures); figures are ignored here.
    ghostSeq = result[0] if ghost_kwargs["return_ctrl_figs"] else result
    return ghostSeq if ghostSeq else None
|
|
570
|
+
|
|
571
|
+
def make_batch_model(model, params, t=0):
    """
    Wrap a single-point model into a batch version using jax.vmap.

    Parameters
    ----------
    model : callable
        Function with signature model(t, z, params) -> dz/dt for a single
        state vector z. Must be traceable by JAX.
    params : any
        Model parameters, bound into the returned function unchanged.
    t : float, optional
        Time value passed to model for every point. Default is 0, which is
        what ghostID uses when it builds the batch model itself.

    Returns
    -------
    callable
        Function Zs -> dZs/dt (params and t are already bound), where Zs has
        shape (num_points, n); row k of the result is model(t, Zs[k], params).
    """
    def single(z):
        return model(t, z, params)

    return jax.vmap(single)
|
|
586
|
+
|
|
587
|
+
def find_local_Qminimum(model, x0, model_params, delta, *, global_method="lhs", local_method="L-BFGS-B",
|
|
588
|
+
global_options=None, local_options=None, verbose=False):
|
|
589
|
+
|
|
590
|
+
"""
|
|
591
|
+
Search for a Q-minimum within a region of radius delta around a given point
|
|
592
|
+
x0 in phase space, where Q(x) = 0.5 * ||F(x)||^2 and F is the vector field
|
|
593
|
+
defined by model. The search combines a global strategy (either sampling or
|
|
594
|
+
a global optimisation algorithm) with an optional subsequent local
|
|
595
|
+
optimisation step.
|
|
596
|
+
|
|
597
|
+
Parameters
|
|
598
|
+
----------
|
|
599
|
+
model : callable
|
|
600
|
+
Python function describing the system dynamics.
|
|
601
|
+
x0 : array-like
|
|
602
|
+
Centre of the search region in phase space. The Q-minimum is sought
|
|
603
|
+
within a ball of radius delta around this point.
|
|
604
|
+
model_params : list or array-like
|
|
605
|
+
Parameters passed as arguments to model.
|
|
606
|
+
delta : float
|
|
607
|
+
Radius of the region around x0 in which to search for a Q-minimum.
|
|
608
|
+
global_method : str, optional
|
|
609
|
+
Strategy used for the global search stage. Options are:
|
|
610
|
+
|
|
611
|
+
'lhs' (default)
|
|
612
|
+
Draw a Latin-hypercube sample from the delta-ball around x0,
|
|
613
|
+
evaluate Q at each sample point, and pass the k_seeds points with
|
|
614
|
+
the lowest Q-values to the local optimiser as starting points.
|
|
615
|
+
Controlled via global_options (see below).
|
|
616
|
+
'differential_evolution'
|
|
617
|
+
Use scipy.optimize.differential_evolution to search for a
|
|
618
|
+
Q-minimum within the delta-ball.
|
|
619
|
+
'dual_annealing'
|
|
620
|
+
Use scipy.optimize.dual_annealing to search for a Q-minimum
|
|
621
|
+
within the delta-ball.
|
|
622
|
+
'basin_hopping'
|
|
623
|
+
Use scipy.optimize.basin_hopping to search for a Q-minimum
|
|
624
|
+
within the delta-ball.
|
|
625
|
+
|
|
626
|
+
local_method : str, optional
|
|
627
|
+
Local optimisation method applied after the global search stage.
|
|
628
|
+
Accepts any method available in scipy.optimize.minimize.
|
|
629
|
+
Default is 'L-BFGS-B'.
|
|
630
|
+
global_options : dict or None, optional
|
|
631
|
+
Options controlling the global search stage. For the scipy.optimize
|
|
632
|
+
global methods ('differential_evolution', 'dual_annealing',
|
|
633
|
+
'basin_hopping'), any keyword argument accepted by the corresponding
|
|
634
|
+
scipy.optimize function may be provided. For global_method='lhs',
|
|
635
|
+
the following keys are recognised:
|
|
636
|
+
|
|
637
|
+
n_samples : int or None
|
|
638
|
+
Size of the Latin-hypercube sample. If None, chosen automatically
|
|
639
|
+
as min(2000, max(200, 20 * dim)), where dim is the dimension of
|
|
640
|
+
the model.
|
|
641
|
+
k_seeds : int or None
|
|
642
|
+
Number of lowest-Q sample points passed to the local optimiser as
|
|
643
|
+
starting points. If None, chosen automatically as
|
|
644
|
+
min(5, max(2, int(sqrt(dim)))), where dim is the dimension of
|
|
645
|
+
the model.
|
|
646
|
+
seed : int or None
|
|
647
|
+
Random seed for the Latin-hypercube sampler. Setting a specific
|
|
648
|
+
value ensures exact reproducibility of the sample. Default is None.
|
|
649
|
+
|
|
650
|
+
Default for global_options is None (all sub-options use their defaults).
|
|
651
|
+
local_options : dict or None, optional
|
|
652
|
+
Options passed directly to scipy.optimize.minimize for the local
|
|
653
|
+
optimisation step. Accepts any keyword argument recognised by the
|
|
654
|
+
chosen local_method. Default is None.
|
|
655
|
+
verbose : bool, optional
|
|
656
|
+
If True, enables control outputs during the search. Default is False.
|
|
657
|
+
|
|
658
|
+
Returns
|
|
659
|
+
-------
|
|
660
|
+
result : OptimizeResult or similar
|
|
661
|
+
The identified Q-minimum within the delta-ball around x0.
|
|
662
|
+
"""
|
|
663
|
+
|
|
664
|
+
F = model
|
|
665
|
+
p = model_params
|
|
666
|
+
|
|
667
|
+
x0 = np.asarray(x0, dtype=float)
|
|
668
|
+
dim = x0.size
|
|
669
|
+
bounds = [(x0[i] - delta, x0[i] + delta) for i in range(dim)]
|
|
670
|
+
|
|
671
|
+
global_options = {} if global_options is None else dict(global_options)
|
|
672
|
+
local_options = {} if local_options is None else dict(local_options)
|
|
673
|
+
|
|
674
|
+
# --------------------------------------------------
|
|
675
|
+
# Define Q and grad Q
|
|
676
|
+
# --------------------------------------------------
|
|
677
|
+
def Q_func(x):
|
|
678
|
+
z = jnp.asarray(x)
|
|
679
|
+
return float(0.5 * jnp.sum(F(0.0, z, p) ** 2)) #SHOULD THERE NOT BE A SQRT HERE???
|
|
680
|
+
|
|
681
|
+
grad_Q = jax.grad(lambda z: 0.5 * jnp.sum(F(0.0, z, p) ** 2))
|
|
682
|
+
|
|
683
|
+
def grad_Q_np(x):
|
|
684
|
+
return np.asarray(grad_Q(jnp.asarray(x)))
|
|
685
|
+
|
|
686
|
+
# --------------------------------------------------
|
|
687
|
+
# Global search
|
|
688
|
+
# --------------------------------------------------
|
|
689
|
+
if verbose:
|
|
690
|
+
print(f"[Q-min] Global search method: {global_method}")
|
|
691
|
+
|
|
692
|
+
candidate_points = []
|
|
693
|
+
|
|
694
|
+
# ---- LHS ------------------------------------------------
|
|
695
|
+
if global_method == "lhs":
|
|
696
|
+
# Compute dimension-aware defaults if None
|
|
697
|
+
n_samples = global_options.get("n_samples", None)
|
|
698
|
+
if n_samples is None:
|
|
699
|
+
n_samples = min(2000, max(200, 20 * dim))
|
|
700
|
+
|
|
701
|
+
k_seeds = global_options.get("k_seeds", None)
|
|
702
|
+
if k_seeds is None:
|
|
703
|
+
k_seeds = min(5, max(2, int(np.sqrt(dim))))
|
|
704
|
+
|
|
705
|
+
seed = global_options.get("seed", None)
|
|
706
|
+
|
|
707
|
+
if verbose:
|
|
708
|
+
print(
|
|
709
|
+
f"[Q-min] LHS: n_samples={n_samples}, "
|
|
710
|
+
f"k_seeds={k_seeds}, seed={seed}"
|
|
711
|
+
)
|
|
712
|
+
|
|
713
|
+
sampler = qmc.LatinHypercube(d=dim, seed=seed)
|
|
714
|
+
samples = sampler.random(n=n_samples)
|
|
715
|
+
points = qmc.scale(
|
|
716
|
+
samples,
|
|
717
|
+
[b[0] for b in bounds],
|
|
718
|
+
[b[1] for b in bounds],
|
|
719
|
+
)
|
|
720
|
+
|
|
721
|
+
Q_vals = np.array([Q_func(x) for x in points])
|
|
722
|
+
best_idx = np.argsort(Q_vals)[:k_seeds]
|
|
723
|
+
candidate_points = points[best_idx]
|
|
724
|
+
|
|
725
|
+
# ---- Differential Evolution -----------------------------
|
|
726
|
+
elif global_method == "differential_evolution":
|
|
727
|
+
res = differential_evolution(
|
|
728
|
+
Q_func,
|
|
729
|
+
bounds,
|
|
730
|
+
disp=verbose,
|
|
731
|
+
**global_options,
|
|
732
|
+
)
|
|
733
|
+
candidate_points = [res.x]
|
|
734
|
+
|
|
735
|
+
# ---- Dual Annealing -------------------------------------
|
|
736
|
+
elif global_method == "dual_annealing":
|
|
737
|
+
res = dual_annealing(
|
|
738
|
+
Q_func,
|
|
739
|
+
bounds,
|
|
740
|
+
**global_options,
|
|
741
|
+
)
|
|
742
|
+
candidate_points = [res.x]
|
|
743
|
+
|
|
744
|
+
# ---- Basin Hopping --------------------------------------
|
|
745
|
+
elif global_method == "basin_hopping":
|
|
746
|
+
minimizer_kwargs = {
|
|
747
|
+
"method": local_method,
|
|
748
|
+
"jac": grad_Q_np,
|
|
749
|
+
"bounds": bounds if local_method in {"L-BFGS-B", "TNC", "SLSQP"} else None,
|
|
750
|
+
}
|
|
751
|
+
res = basinhopping(
|
|
752
|
+
Q_func,
|
|
753
|
+
x0,
|
|
754
|
+
minimizer_kwargs=minimizer_kwargs,
|
|
755
|
+
**global_options,
|
|
756
|
+
)
|
|
757
|
+
candidate_points = [res.x]
|
|
758
|
+
|
|
759
|
+
else:
|
|
760
|
+
raise ValueError(f"Unknown global_method '{global_method}'")
|
|
761
|
+
|
|
762
|
+
# --------------------------------------------------
|
|
763
|
+
# Local refinement
|
|
764
|
+
# --------------------------------------------------
|
|
765
|
+
if local_method is None:
|
|
766
|
+
Q_vals = [Q_func(x) for x in candidate_points]
|
|
767
|
+
best = int(np.argmin(Q_vals))
|
|
768
|
+
return candidate_points[best], Q_vals[best], None
|
|
769
|
+
|
|
770
|
+
results_local = []
|
|
771
|
+
|
|
772
|
+
for x_start in candidate_points:
|
|
773
|
+
res = minimize(
|
|
774
|
+
Q_func,
|
|
775
|
+
x_start,
|
|
776
|
+
jac=grad_Q_np,
|
|
777
|
+
method=local_method,
|
|
778
|
+
bounds=bounds if local_method in {"L-BFGS-B", "TNC", "SLSQP"} else None,
|
|
779
|
+
options={
|
|
780
|
+
"disp": verbose,
|
|
781
|
+
**local_options,
|
|
782
|
+
},
|
|
783
|
+
)
|
|
784
|
+
results_local.append(res)
|
|
785
|
+
|
|
786
|
+
best_res = min(results_local, key=lambda r: r.fun)
|
|
787
|
+
|
|
788
|
+
if verbose:
|
|
789
|
+
print(f"[Q-min] Final Q = {best_res.fun:.3e}")
|
|
790
|
+
print(f"[Q-min] x = {best_res.x}")
|
|
791
|
+
|
|
792
|
+
return best_res.x, best_res.fun, best_res
|
|
793
|
+
|
|
794
|
+
def qOnGrid(model, model_params, coords=None, n_points=50, ranges=None, overrides=None, indexing="ij", jit=False, dim=None):
    """
    Evaluate Q(x) = 0.5 * ||F(x)||^2 on a phase space grid, where F is the
    vector field defined by the model.

    Parameters
    ----------
    model : callable
        Python function describing the system dynamics, called as
        model(t, x, model_params).
    model_params : list or array-like
        Parameters passed as arguments to model.
    coords : list of 1D array-like or None, optional
        Explicit phase space grid on which to evaluate Q-values, provided as a
        list of 1D arrays (one per dimension). If None, the grid is constructed
        automatically using n_points, ranges, and overrides (see below).
        Default is None.
    n_points : int or list of int, optional
        Number of grid points along each dimension. A single integer (Python
        or NumPy scalar) applies the same value to all dimensions; a list of
        integers sets each dimension individually, with the last entry reused
        for any remaining dimensions. Ignored if coords is provided.
        Default is 50.
    ranges : tuple or list of tuple, optional
        Bounds of the grid along each dimension, given as a single
        (lower, upper) tuple applied to all dimensions, or a list of such
        tuples to set each dimension individually (the last entry is reused
        for any remaining dimensions). Ignored if coords is provided.
        Default is (-2, 2).
    overrides : dict or None, optional
        Dictionary for overriding n_points and/or ranges for specific
        individual axes without affecting the others. Keys are axis indices
        (integers) and values are dictionaries with keys 'n' (int) and/or
        'range' (tuple). For example, to set axis 1 to 100 points over
        (-5.0, 5.0): {1: {'n': 100, 'range': (-5.0, 5.0)}}. Ignored if
        coords is provided. Default is None.
    indexing : str, optional
        Indexing convention passed to jax.numpy.meshgrid. Default is 'ij'
        (matrix-style indexing, where the first index varies along rows).
    jit : bool, optional
        If True, enables JAX just-in-time (JIT) compilation to speed up
        evaluation of Q on the grid. Default is False.
    dim : int or None, optional
        Phase-space dimension; only used when coords is None. If None
        (default), it is inferred from the length of
        model(0.0, jnp.zeros(1), model_params). That probe requires the model
        to accept a length-1 state and reports 1 for models acting
        element-wise on x (e.g. ``-x``); pass dim explicitly in such cases.

    Returns
    -------
    Q_grid : jax.Array
        Q-values at each grid point. The shape is given by the lengths of the
        coordinate axes (ordered according to `indexing`).
    grid_points : jax.Array
        Grid coordinates stacked along the last axis, i.e. of shape
        Q_grid.shape + (number of dimensions,).
    """

    if coords is None:
        if dim is None:
            # Probe the model once to infer the phase-space dimension.
            dim = model(0.0, jnp.zeros(1), model_params).shape[0]
        dim = int(dim)

        # Normalise ranges to a per-axis list of (lower, upper) pairs.
        # np.ndim(...) == 0 also accepts NumPy scalar bounds, which an
        # isinstance(..., (int, float)) check would miss.
        if ranges is None:
            ranges = [(-2.0, 2.0)] * dim
        elif np.ndim(ranges[0]) == 0:
            ranges = [ranges] * dim

        # Normalise n_points to a per-axis list (NumPy integers included).
        if np.ndim(n_points) == 0:
            n_points = [n_points] * dim

        overrides = overrides or {}
        coords = []
        for d in range(dim):
            # Axes beyond the supplied lists reuse the last entry.
            n = n_points[d] if d < len(n_points) else n_points[-1]
            r = ranges[d] if d < len(ranges) else ranges[-1]
            axis_override = overrides.get(d, {})
            n = axis_override.get("n", n)
            r = axis_override.get("range", r)
            coords.append(jnp.linspace(r[0], r[1], int(n)))

    meshes = jnp.meshgrid(*coords, indexing=indexing)
    grid_points = jnp.stack(meshes, axis=-1)

    def core(points):
        # Flatten to (n_grid_points, dim), evaluate F at every point with
        # vmap, reduce to Q = 0.5 * sum_i F_i^2 and restore the grid shape.
        flat_pts = points.reshape(-1, points.shape[-1])
        values = jax.vmap(lambda pt: model(0.0, pt, model_params))(flat_pts)
        Q_flat = 0.5 * jnp.sum(values ** 2, axis=-1)
        return Q_flat.reshape(points.shape[:-1])

    if jit:
        core = jax.jit(core)
    return core(grid_points), grid_points
|
|
877
|
+
|
|
878
|
+
def track_ghost_branch(ghost, model, model_params, par_nr, par_steps, dpar, t_end, dt, delta=0.5, icStep=0.1, mode="first",
                       epsilon_gid=0.1, solve_ivp_method='RK45', rtol=1.e-3, atol=1.e-6, **kwargs):
    """
    Track a ghost state across a parameter sweep by iteratively updating the parameter
    and re-identifying the ghost at each step.

    At every step:
      1. a Q-minimum (xQmin) is searched within +/- delta of the current ghost
         position,
      2. two initial conditions are placed at +icStep and -icStep from xQmin
         (via icAtQmin); the one whose short trajectory (length 5*dt) ends
         closer to xQmin than it started is kept,
      3. that initial condition is integrated up to t_end and GhostID is run on
         the trajectory using the current parameter values,
      4. an accepted ghost becomes the current ghost for the next step.

    Parameters
    ----------
    ghost : dict
        Ghost to be tracked, identified either by 'ghostID' or 'ghostID_phaseSpaceSample'.
        Its 'position' and 'dimension' entries are used.
    model : callable
        Python function describing the system dynamics, called as
        model(t, x, model_params).
    model_params : list or array-like
        Parameters for the model function. Not modified; a copy is swept.
    par_nr : int
        Index of the parameter in model_params to be varied during the sweep.
    par_steps : int
        Number of parameter increments. Ghosts are searched at the
        par_steps + 1 values p0, p0 + dpar, ..., p0 + par_steps * dpar.
    dpar : float
        Size of the parameter increment per iteration. Use a positive value to increase
        the parameter and a negative value to decrease it.
    t_end : float
        Length of the trajectory integrated at each iteration step.
    dt : float
        Step-size used for numerical integration.
    delta : float, optional
        Size of the region around the current ghost position (xg) in which to search
        for xQmin (the point of minimum speed). Default is 0.5.
    icStep : float, optional
        Distance from xQmin at which to initialize a trajectory that will be analyzed
        by GhostID. Default is 0.1.
    mode : {'first', 'closest'}, optional
        Strategy for selecting among multiple ghosts potentially identified along a
        trajectory.
        - 'first' : take the first ghost found (default).
        - 'closest' : take the ghost closest in phase space to xQmin.
        Any other value prints a message and falls back to 'first'.
    epsilon_gid : float, optional
        Threshold parameter for GhostID (see section 1.1 of the paper). Default is 0.1.
    solve_ivp_method : str, optional
        Integration method passed to scipy.integrate.solve_ivp (e.g. 'RK45').
        Default is 'RK45'.
    rtol : float, optional
        Relative tolerance for the numerical integrator. Default is 1e-3.
    atol : float, optional
        Absolute tolerance for the numerical integrator. Default is 1e-6.
    **kwargs
        Optional keyword arguments (parsed by parse_kwargs):
        - distQminThr (float): Maximum allowable distance between the identified ghost
          and xQmin. Any ghost candidate whose distance from xQmin exceeds this
          threshold is rejected. Default is infinity (no constraint).

    Returns
    -------
    ghostPositions : array-like
        Positions in phase space of the ghosts found at each parameter step.
    parSeq : array-like
        Sequence of parameter values at which each ghost was identified.
    ghostSeq_p : list
        Full list of ghost objects found at each parameter value, enabling extraction
        of additional ghost properties beyond phase-space position.
    control_plots : optional
        Control plots for GhostID, returned only if the corresponding option is enabled.

    If neither initial condition moves towards xQmin, a message is printed and
    (None, None, None) is returned, or (None, None, None, None) when control
    figures are requested. The sweep stops early (returning the results so far)
    when GhostID finds no ghost.
    """

    # Parse and validate kwargs
    config = parse_kwargs(**kwargs)

    display_warnings = config['display_warnings']

    # Extract parameters from config
    delta_gid = config['delta_gid']
    # Copy so that the default width below does not leak into the caller's dict.
    peak_kwargs = dict(config['peak_kwargs'])
    peak_kwargs.setdefault("width", 5 * dt)  # default if not supplied
    batchModel = config['batchModel']
    eigval_NN_sorting = config['eigval_NN_sorting']
    ev_outlier_removal_ws = config['ev_outlier_removal_ws']
    ev_outlier_removal_k = config['ev_outlier_removal_k']
    evLimit = config['evLimit']
    slopeLimits = config['slopeLimits']

    # Plotting control settings
    return_ctrl_figs = config['return_ctrl_figs']
    with_figs = bool(return_ctrl_figs)
    ctrlOutputs = kwargs.get("ctrlOutputs", {})

    # Q-min search related kwargs
    distQminThr = config['distQminThr']
    qmin_glob_method = config['qmin_glob_method']
    qmin_loc_method = config['qmin_loc_method']
    qmin_glob_options = config['qmin_glob_options']
    qmin_loc_options = config['qmin_loc_options']

    # Validate the selection mode once instead of re-running a whole
    # (expensive) parameter step after discovering an invalid value.
    if mode not in ("first", "closest"):
        print("Unknown mode argument. Use default mode instead.")
        mode = "first"

    ghostSeq_p = []
    ctrl_figures_p = []
    parSeq = []

    parNext = model_params[par_nr]
    try:
        model_params_ = np.asarray(model_params).copy()
        # An integer/bool array would silently truncate parNext + dpar on assignment.
        if model_params_.dtype.kind in "biu":
            model_params_ = model_params_.astype(float)
    except (ValueError, TypeError):
        # Ragged / non-numeric parameter containers: fall back to a shallow copy.
        model_params_ = model_params.copy()

    # Current ghost; re-centred on every accepted ghost so that the Q-minimum
    # search region follows the branch as the parameter changes.
    ghost_ = ghost

    t_eval_short = np.arange(0, 5 * dt, dt)
    t_eval = np.arange(0, t_end, dt)

    with tqdm(total=par_steps + 1) as pbar:
        for i in range(par_steps + 1):
            # max(..., 1) avoids a ZeroDivisionError when par_steps == 0.
            pct = 100 * i / max(par_steps, 1)
            pbar.set_description(f"Progress: {pct:6.2f}% | param value={parNext:.5f}")

            x0 = ghost_["position"]

            qmin = find_local_Qminimum(model, x0, model_params_, delta,
                                       global_method=qmin_glob_method,
                                       local_method=qmin_loc_method,
                                       global_options=qmin_glob_options,
                                       local_options=qmin_loc_options,
                                       verbose=False)[0]

            ic_plus, _, _ = icAtQmin(qmin, icStep, ghost_["dimension"], model, model_params_)
            ic_minus, _, _ = icAtQmin(qmin, -icStep, ghost_["dimension"], model, model_params_)

            sol_plus = solve_ivp(model, (0, 5 * dt), jnp.real(ic_plus), t_eval=t_eval_short,
                                 rtol=rtol, atol=atol, args=(model_params_,), method=solve_ivp_method)
            sol_minus = solve_ivp(model, (0, 5 * dt), jnp.real(ic_minus), t_eval=t_eval_short,
                                  rtol=rtol, atol=atol, args=(model_params_,), method=solve_ivp_method)

            # Keep the initial condition whose short trajectory ends closer to
            # qmin than it started.
            if np.linalg.norm(qmin - sol_plus.y[:, -1]) < np.linalg.norm(qmin - ic_plus):
                ic_pick = ic_plus
            elif np.linalg.norm(qmin - sol_minus.y[:, -1]) < np.linalg.norm(qmin - ic_minus):
                ic_pick = ic_minus
            else:
                print("Terminating track_ghost_branch: Error in chosing initial conditions around qmin (both trajectories are diverging). Try different global/local method/options for finding qmin or different icStep size.")
                if with_figs:
                    return None, None, None, None
                return None, None, None

            ic_pick = jnp.real(ic_pick)

            sol = solve_ivp(model, (0, t_end), ic_pick, t_eval=t_eval,
                            rtol=rtol, atol=atol, args=(model_params_,), method=solve_ivp_method)

            # Analyse the trajectory with the SAME (current) parameters it was
            # integrated with.
            gid_output = ghostID(model, model_params_, dt, sol.y.T,
                                 epsilon_gid, delta_gid=delta_gid, peak_kwargs=peak_kwargs,
                                 batchModel=batchModel, return_ctrl_figs=return_ctrl_figs, ctrlOutputs=ctrlOutputs,
                                 evLimit=evLimit, slopeLimits=slopeLimits, eigval_NN_sorting=eigval_NN_sorting,
                                 ev_outlier_removal_ws=ev_outlier_removal_ws, ev_outlier_removal_k=ev_outlier_removal_k,
                                 display_warnings=display_warnings)

            if with_figs:
                ghostSeq, ctrl_figures = gid_output
                ctrl_figures_p.append(ctrl_figures)
            else:
                ghostSeq = gid_output

            if len(ghostSeq) == 0:
                print("No further ghosts found.")
                break

            if mode == "first":
                idx = 0
                distance = np.linalg.norm(ghostSeq[0]["position"] - qmin)
            else:  # mode == "closest"
                positions = np.array([g["position"] for g in ghostSeq])
                distances = np.linalg.norm(positions - qmin, axis=1)
                idx = int(np.argmin(distances))
                distance = distances[idx]

            if distance < distQminThr:
                ghostSeq_p.append(ghostSeq[idx])
                parSeq.append(parNext)
                ghost_ = ghostSeq[idx]

            # Advance the swept parameter.
            parNext = parNext + dpar
            model_params_[par_nr] = parNext
            pbar.update(1)

    ghostPositions = np.asarray([g["position"] for g in ghostSeq_p])

    if with_figs:
        return ghostPositions, np.asarray(parSeq), ghostSeq_p, ctrl_figures_p
    return ghostPositions, np.asarray(parSeq), ghostSeq_p
|
|
1075
|
+
|
|
1076
|
+
def ghost_connections(gSeqs):
    """
    Turn a list of ghost sequences into a directed adjacency matrix.

    Parameters
    ----------
    gSeqs : list of list of dict
        Ghost sequences (e.g. generated by ghostID and unified with
        unify_IDs). Every dict must carry an 'id' entry.

    Returns
    -------
    adjM : np.ndarray of shape (n, n)
        adjM[a, b] == 1 if ghost labels[a] is directly followed by ghost
        labels[b] in at least one sequence, 0 otherwise.
    labels : list of str
        Ghost ids (those starting with 'G') in order of first appearance;
        they label the rows/columns of adjM.

    Raises
    ------
    ValueError
        If a sequence contains an id that does not start with 'G' and is
        part of a transition.
    """
    id_seqs = [[g["id"] for g in seq] for seq in gSeqs]

    labels = []
    for ids in id_seqs:
        for gid in ids:
            if gid[:1] == "G" and gid not in labels:
                labels.append(gid)

    adjM = np.zeros((len(labels), len(labels)))
    for ids in id_seqs:
        # Each consecutive pair (src -> dst) becomes a directed edge.
        for src, dst in zip(ids[:-1], ids[1:]):
            adjM[labels.index(src), labels.index(dst)] = 1

    return adjM, labels
|
|
1107
|
+
|
|
1108
|
+
def unique_ghosts(gSeq):
    """
    Collect the distinct ghosts appearing in a list of ghost sequences.

    Parameters
    ----------
    gSeq : list of list of dict
        Ghost sequences (typically after unify_IDs). Every dict must carry an
        'id' entry.

    Returns
    -------
    list of dict
        The first occurrence of every ghost whose id starts with 'G', in
        order of first appearance. The returned dicts are the original
        objects, not copies.
    """
    seen = set()
    ghostsUnique = []
    for seq in gSeq:
        for g in seq:
            gid = g["id"]
            if gid[:1] != "G" or gid in seen:
                continue
            seen.add(gid)
            ghostsUnique.append(g)
    return ghostsUnique
|
|
1123
|
+
|
|
1124
|
+
def unify_IDs(seqs, delta_unify=0.1, update=True):
    """
    Unify ids across multiple sequences of transient ghost state objects found in timecourse simulations.

    Ghosts of the first sequence keep their ids and serve as references.
    Every ghost of a later sequence receives the id of the nearest known
    ghost within delta_unify; if there is none, it gets a fresh id
    'G{k}' (k one larger than the highest index seen so far) and becomes a
    reference itself.

    Parameters
    ----------
    seqs : list[list[dict]]
        Each inner list corresponds to one simulation run.
        Each dict represents a ghost state and must contain:
            - 'position'  : np.ndarray, phase-space position of the ghost
            - 'id'        : str of the form 'G{i}'
            - 'q-value'   : float, scalar quality / stability measure
            - 'dimension' : int, dimension associated with the ghost

    delta_unify : float, optional (default=0.1)
        Distance threshold below which two ghost states are considered
        identical (i.e. the same ghost across runs).

    update : bool, optional (default=True)
        If True, perform a second pass after ID unification that
        synchronizes properties across all ghosts sharing the same ID:
        position and q-value are taken from the ghost with minimal q-value,
        dimension is set to the maximum dimension.

    Returns
    -------
    Seqs : list[list[dict]]
        A new list with the same nesting as seqs, holding shallow copies of
        the ghost dicts with unified IDs and (optionally) updated ghost
        properties. The input dicts are not modified.

    Raises
    ------
    ValueError
        If an id does not start with 'G' (or, in the first sequence, has no
        integer suffix).
    KeyError
        If update is True and a ghost has no 'q-value' entry.
    """

    # Shallow-copy every ghost dict: ids and properties are rewritten below
    # and the caller's objects must stay untouched.
    Seqs = [[dict(obj) for obj in seq] for seq in seqs]
    if not Seqs:
        return Seqs

    # ------------------------------------------------------------------
    # STEP 1: Initialize reference ghosts from the first sequence
    # ------------------------------------------------------------------

    # Dictionary mapping ghost ID -> representative position
    known_G = {}

    # Track the highest numerical ghost index encountered so far
    max_g = 0

    for obj in Seqs[0]:
        pid = obj['id']
        if not pid.startswith('G'):
            raise ValueError(f"Unrecognized id '{pid}'")
        idx = int(pid[1:])  # extract numerical part of ID
        known_G[pid] = np.array(obj['position'], copy=True)
        max_g = max(max_g, idx)

    # ------------------------------------------------------------------
    # STEP 2: Unify IDs across all subsequent sequences
    # ------------------------------------------------------------------

    for seq in Seqs[1:]:
        for obj in seq:
            orig = obj['id']
            if not orig.startswith('G'):
                raise ValueError(f"Unrecognized id '{orig}'")
            pos = np.asarray(obj['position'])

            # Pick the NEAREST known ghost within delta_unify, so that the
            # result does not depend on insertion order when several
            # references are within the threshold.
            best_id, best_dist = None, delta_unify
            for pid, refpos in known_G.items():
                dist = np.linalg.norm(pos - refpos)
                if dist < best_dist:
                    best_id, best_dist = pid, dist

            # If no match was found, register a new ghost ID
            if best_id is None:
                max_g += 1
                best_id = f'G{max_g}'
                known_G[best_id] = pos.copy()

            obj['id'] = best_id

    # ------------------------------------------------------------------
    # STEP 3 (optional): Update ghost properties across identical IDs
    # ------------------------------------------------------------------

    if update:
        # Collect all ghosts grouped by ID
        ghosts_by_id = {}
        for seq in Seqs:
            for obj in seq:
                ghosts_by_id.setdefault(obj['id'], []).append(obj)

        for gid, ghosts in ghosts_by_id.items():
            # (a) Reference = ghost with minimal q-value
            missing = [o for o in ghosts if 'q-value' not in o]
            if missing:
                raise KeyError(
                    f"Ghosts with id {gid} are missing 'q-value' "
                    f"(available keys: {[list(m.keys()) for m in missing]})"
                )
            ref = min(ghosts, key=lambda o: o['q-value'])
            ref_pos = ref['position'].copy()
            ref_q = ref['q-value']

            # (b) Maximal dimension across ghosts with this ID
            max_dim = max(o['dimension'] for o in ghosts)

            # (c) Update all ghosts with synchronized values
            for o in ghosts:
                o['position'] = ref_pos.copy()
                o['q-value'] = ref_q
                o['dimension'] = max_dim

    return Seqs
|
|
1244
|
+
|
|
1245
|
+
def draw_network(adj_matrix, nodeCols, nlbls, layout="fdp", graphviz_args=None, layout_kwargs=None, rankdir="TB", node_size=1800, label_font_size=16.5, font="Arial"):
    """
    Draw a directed network (e.g. the ghost adjacency matrix produced by
    ghost_connections) with coloured nodes, custom arrows and labels.

    Parameters
    ----------
    adj_matrix : np.ndarray, square
        Adjacency matrix. adj_matrix[i, j] == 1 draws a black edge i -> j,
        adj_matrix[i, j] == -1 draws a red edge i -> j; other values are
        ignored.
    nodeCols : color or sequence of colors
        Node colours passed to nx.draw_networkx_nodes.
    nlbls : sequence
        Node labels; nlbls[i] labels node i.
    layout : str or callable, optional
        layout options
        --------------
        Graphviz:
            'fdp', 'dot', 'neato', 'sfdp', 'circo'
        Semantic aliases:
            'hierarchical' -> Graphviz 'dot'
        NetworkX:
            any nx.*_layout function (NOT nx.draw_*)
        Default is 'fdp'.
    graphviz_args : str or None, optional
        Command-line arguments passed to Graphviz. If None, a default string
        using rankdir is built.
    layout_kwargs : dict or None, optional
        Keyword arguments passed to a callable layout.
    rankdir : str, optional
        Graphviz rank direction used in the default graphviz_args.
    node_size : float, optional
        Node size passed to nx.draw_networkx_nodes.
    label_font_size : float, optional
        Font size of the node labels.
    font : str, optional
        Font family of the node labels.

    Raises
    ------
    ValueError
        If layout is a nx.draw_* function, or neither a string nor a callable.

    Notes
    -----
    No axes object is passed to the networkx drawing calls, so they draw on
    the current matplotlib axes.
    """

    if layout_kwargs is None:
        layout_kwargs = {}

    nw_dim = adj_matrix.shape[0]

    # Build the graph only from the edges that are actually drawn. Building it
    # from adj_matrix.T as well would add every edge a second time in the
    # reverse direction and distort the layout.
    G = nx.DiGraph()
    G.add_nodes_from(range(nw_dim))

    # --- Define edges ------------------------------------------------------
    inhEdges, actEdges = [], []
    for i in range(nw_dim):
        for ii in range(nw_dim):
            if adj_matrix[i, ii] == 1:
                G.add_edge(i, ii, weight=1)
                actEdges.append((i, ii))
            elif adj_matrix[i, ii] == -1:
                G.add_edge(i, ii, weight=-1)
                inhEdges.append((i, ii))

    # --- Layout handling ---------------------------------------------------
    if graphviz_args is None:
        graphviz_args = (
            f"-Grankdir={rankdir} "
            "-Nwidth=350 -Nheight=350 -Nfixedsize=true "
            "-Goverlap=scale -Gnodesep=5000 -Granksep=200 "
            "-Nshape=oval -Nfontsize=14 -Econstraint=true"
        )

    if layout == "hierarchical":
        layout = "dot"

    if isinstance(layout, str):
        pos = graphviz_layout(G, prog=layout, args=graphviz_args)

    elif callable(layout):
        # functools.partial objects and similar callables have no __name__.
        layout_name = getattr(layout, "__name__", "")
        if layout_name.startswith("draw_"):
            raise ValueError(
                f"{layout_name} is a drawing function. "
                "Use a layout function like nx.spectral_layout instead."
            )
        pos = layout(G, **layout_kwargs)

    else:
        raise ValueError("layout must be a string or a callable")

    # --- Draw nodes --------------------------------------------------------
    node_opts = {
        "node_size": node_size,
        "edgecolors": "white",
    }
    nx.draw_networkx_nodes(
        G, pos, node_color=nodeCols, **node_opts, alpha=0.90
    )

    # --- Draw edges --------------------------------------------------------
    draw_custom_edges(
        G, pos, actEdges,
        color="k", head_length=8, head_width=4,
        width=1.1, trim_fraction=0.2
    )
    draw_custom_edges(
        G, pos, inhEdges,
        color="red", head_length=8, head_width=4,
        width=1.1, trim_fraction=0.2
    )

    # --- Draw labels -------------------------------------------------------
    labels = {i: nlbls[i] for i in range(nw_dim)}
    nx.draw_networkx_labels(
        G,
        pos,
        labels,
        font_size=label_font_size,
        font_weight="bold",
        font_family=font,
    )
|